{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Movie Reviews with bert-for-tf2 on TPU.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "TPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/kpe/bert-for-tf2/blob/master/examples/tpu_movie_reviews.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XgnJDqeiSPqc",
        "colab_type": "text"
      },
      "source": [
        "# Overview\n",
        "\n",
        "This colab notebook demonstrates how to fine-tune a \n",
        "BERT-based sentiment classifier on the IMDB Movie Reviews \n",
        "dataset using a free Colab TPU.\n",
        "\n",
        "We'll be using the TensorFlow Keras API implementation of BERT from [kpe/bert-for-tf2](https://github.com/kpe/bert-for-tf2)\n",
        "and the pre-trained BERT weights from [google-research/bert](https://github.com/google-research/bert).\n",
        "Instead of fine-tuning all of the BERT weights, we'll use the [adapter-BERT](https://arxiv.org/abs/1902.00751) architecture to fine-tune only a fraction of the weights, while keeping the original BERT weights frozen.\n",
        "\n",
        "\n",
        "The main steps for training a Keras model on a TPU in Colab are:\n",
        " - **Google Storage Bucket** - TPUs currently need access to a Google Storage bucket for loading training data and weights, and write access to it for storing model checkpoints.\n",
        " - **GCP Authentication** - once you have a storage bucket, authorizing Colab to use it is easy.\n",
        " - **pre-trained BERT** - we also have to copy the pre-trained BERT weights to our storage bucket (because loading the pre-trained checkpoint needs list permissions).\n",
        " - **TFRecord** - to fully utilize the TPU, we need to feed the training data as efficiently as possible, so we encode our training examples into TFRecord files and read them back with a `TFRecordDataset`.\n",
        " - **TPU Training** - simply creating a Keras model inside a `TPU Distribution Strategy` scope is then enough to place the model on the TPU, ready for training.\n",
        " \n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aB_F9OQqKK_q",
        "colab_type": "text"
      },
      "source": [
        "# Storage Bucket Authentication\n",
        "\n",
        "You need to set up a storage bucket in GCP for storing and loading model weights and for feeding data to the TPU."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "AXsSMfLdLeLC",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "%tensorflow_version 2.x\n",
        "\n",
        "import os\n",
        "import tensorflow as tf"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "XkkQli7WKdvJ",
        "colab_type": "code",
        "outputId": "d79d54af-c793-44f7-a750-95d51015f67a",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "BUCKET = 'YOUR_BUCKET_NAME' #@param {type:\"string\"}\n",
        "\n",
        "OUTPUT_DIR = 'PATH_WITHIN_THE_BUCKET' #@param {type:\"string\"}\n",
        "\n",
        "#@markdown Whether or not to clear/delete the directory and create a new one\n",
        "DO_DELETE = False #@param {type:\"boolean\"}\n",
        "\n",
        "\n",
        "OUTPUT_DIR = 'gs://{}/colab/{}'.format(BUCKET, OUTPUT_DIR)\n",
        "from google.colab import auth\n",
        "auth.authenticate_user()\n",
        "\n",
        "if DO_DELETE:\n",
        "  try:\n",
        "    tf.io.gfile.rmtree(OUTPUT_DIR)\n",
        "  except tf.errors.NotFoundError:\n",
        "    # Doesn't matter if the directory didn't exist\n",
        "    pass\n",
        "tf.io.gfile.makedirs(OUTPUT_DIR)\n",
        "print('***** Model output directory: {} *****'.format(OUTPUT_DIR))"
      ],
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "***** Model output directory: gs://kpe/colab/tpu_movie_reviews *****\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dCpvgG0vwXAZ",
        "colab_type": "text"
      },
      "source": [
        "# Prerequisites"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qFI2_B8ffipb",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!pip install tqdm >> /dev/null\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "hsZvic2YxnTz",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import os\n",
        "import math\n",
        "import datetime\n",
        "\n",
        "\n",
        "from tqdm import tqdm\n",
        "\n",
        "import numpy as np\n",
        "\n",
        "import tensorflow as tf\n",
        "from tensorflow import keras"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Evlk1N78HIXM",
        "colab_type": "code",
        "outputId": "42744ec8-e47c-4462-92e9-084983fe58cb",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "tf.__version__"
      ],
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'2.2.0-rc2'"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 5
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HqYo_14wJ_AY",
        "colab_type": "text"
      },
      "source": [
        "To access the TPU we need a `TPUClusterResolver` instance. In Colab, calling its constructor without arguments is enough:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "TAdrQqEccIva",
        "colab_type": "code",
        "outputId": "ba313166-4ba4-41e8-b856-d79166f6706a",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 70
        }
      },
      "source": [
        "USE_TPU=True\n",
        "try:\n",
        "  tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection\n",
        "  print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])\n",
        "except Exception as ex:\n",
        "  print(ex)\n",
        "  USE_TPU=False\n",
        "\n",
        "print(\"        USE_TPU:\", USE_TPU)\n",
        "print(\"Eager Execution:\", tf.executing_eagerly())\n"
      ],
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Running on TPU  ['10.124.147.138:8470']\n",
            "        USE_TPU: True\n",
            "Eager Execution: True\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cp5wfXDx5SPH",
        "colab_type": "text"
      },
      "source": [
        "Let's also pip install the [bert-for-tf2](https://github.com/kpe/bert-for-tf2) Python package, which contains the Keras implementation of BERT."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jviywGyWyKsA",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 477
        },
        "outputId": "4c7b3f40-73ad-4988-81d7-ce0afd452b63"
      },
      "source": [
        "!pip install --upgrade bert-for-tf2 params-flow sentencepiece #>> /dev/null"
      ],
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Collecting bert-for-tf2\n",
            "  Downloading https://files.pythonhosted.org/packages/ff/84/1bea6c34d38f3e726830d3adeca76e6e901b98cf5babd635883dbedd7ecc/bert-for-tf2-0.14.1.tar.gz (40kB)\n",
            "Collecting params-flow\n",
            "  Downloading https://files.pythonhosted.org/packages/ac/0d/615c0d4aea541b4f47c761263809a02e160e7a2babd175f0ddd804776cf4/params-flow-0.8.0.tar.gz\n",
            "Collecting sentencepiece\n",
            "  Downloading https://files.pythonhosted.org/packages/74/f4/2d5214cbf13d06e7cb2c20d84115ca25b53ea76fa1f0ade0e3c9749de214/sentencepiece-0.1.85-cp36-cp36m-manylinux1_x86_64.whl (1.0MB)\n",
            "Collecting py-params>=0.9.6\n",
            "  Downloading https://files.pythonhosted.org/packages/2a/6d/d41be94cf328ef8f9739dfe4871b73b05aca74d42f841b6dde629af7507d/py-params-0.9.6.tar.gz\n",
            "Requirement already satisfied, skipping upgrade: numpy in /usr/local/lib/python3.6/dist-packages (from params-flow) (1.18.2)\n",
            "Requirement already satisfied, skipping upgrade: tqdm in /usr/local/lib/python3.6/dist-packages (from params-flow) (4.38.0)\n",
            "Building wheels for collected packages: bert-for-tf2, params-flow, py-params\n",
            "  Building wheel for bert-for-tf2 (setup.py) ... done\n",
            "  Created wheel for bert-for-tf2: filename=bert_for_tf2-0.14.1-cp36-none-any.whl size=30083 sha256=13297d0e1a31f9ec72a1e6f01494bbcdaebb7061a188081aae99c84bf9511537\n",
            "  Stored in directory: /root/.cache/pip/wheels/dd/f1/10/861fd7899727e4034293fb1dfef45b00f8cd476d21d3b3821e\n",
            "  Building wheel for params-flow (setup.py) ... done\n",
            "  Created wheel for params-flow: filename=params_flow-0.8.0-cp36-none-any.whl size=15999 sha256=6f330cb61d67b69b8e525ddc304e678d30d9fb5ffdf0a90292eb572de7dfb989\n",
            "  Stored in directory: /root/.cache/pip/wheels/88/41/05/1a9955d1d01575bbd58aab76e22f8c7eeabba905d551576f43\n",
            "  Building wheel for py-params (setup.py) ... done\n",
            "  Created wheel for py-params: filename=py_params-0.9.6-cp36-none-any.whl size=7090 sha256=d99f664f0e1bc04000110327e5ef46ad418d33d0fc9e0fd0cdc187b8c0f6b43a\n",
            "  Stored in directory: /root/.cache/pip/wheels/49/e8/e7/e953ff6a37f696ec894da30a547ee751d1270ed10b5d676c96\n",
            "Successfully built bert-for-tf2 params-flow py-params\n",
            "Installing collected packages: py-params, params-flow, bert-for-tf2, sentencepiece\n",
            "Successfully installed bert-for-tf2-0.14.1 params-flow-0.8.0 py-params-0.9.6 sentencepiece-0.1.85\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ZtI7cKWDbUVc",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "bd2fe84d-2b01-432c-80b3-4557612d1fd5"
      },
      "source": [
        "import params_flow as pf\n",
        "\n",
        "import bert\n",
        "from bert import BertModelLayer\n",
        "from bert.tokenization.bert_tokenization import FullTokenizer\n",
        "from bert import load_stock_weights, params_from_pretrained_ckpt\n",
        "\n",
        "\n",
        "print(\"bert-for-tf2\", bert.__version__)"
      ],
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "bert-for-tf2 0.14.1\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9U4F7ZQY6sUP",
        "colab_type": "text"
      },
      "source": [
        "# The Pre-Trained BERT Model\n",
        "\n",
        "The original pre-trained BERT weights are available in the Google Storage bucket `gs://bert_models/`, but that bucket does not grant the list permission required by the TensorFlow APIs that load weights from a checkpoint, so we have to copy the pre-trained model to our own bucket:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "lw_F488eixTV",
        "colab_type": "code",
        "outputId": "22c19107-e617-453e-a569-c3f3527a339c",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 228
        }
      },
      "source": [
        "bert_ckpt_dir    = \"gs://bert_models/2018_10_18/uncased_L-12_H-768_A-12\"\n",
        "#bert_ckpt_dir    = \"gs://bert_models/2019_05_30/wwm_cased_L-24_H-1024_A-16\"\n",
        "\n",
        "bert_ckpt_file   = os.path.join(bert_ckpt_dir, \"bert_model.ckpt\")\n",
        "bert_config_file = os.path.join(bert_ckpt_dir, \"bert_config.json\")\n",
        "bert_model_name  = os.path.basename(os.path.dirname(bert_ckpt_file))\n",
        "\n",
        "bert_ckpt_files = [\"bert_config.json\",\n",
        "                   \"bert_model.ckpt.data-00000-of-00001\",\n",
        "                   \"bert_model.ckpt.index\",\n",
        "                   \"bert_model.ckpt.meta\",\n",
        "                   \"vocab.txt\"]\n",
        "\n",
        "gs_bert_ckpt_dir = os.path.join(OUTPUT_DIR, \"bert_models\", bert_model_name)\n",
        "if not tf.io.gfile.exists(gs_bert_ckpt_dir):\n",
        "  cmd = \" \".join([os.path.join(bert_ckpt_dir, bert_file)\n",
        "                   for bert_file in bert_ckpt_files])\n",
        "  cmd = \"gsutil -m cp {} {}\".format(cmd, gs_bert_ckpt_dir)\n",
        "  !$cmd\n",
        "\n",
        "!gsutil ls $gs_bert_ckpt_dir"
      ],
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Copying gs://bert_models/2018_10_18/uncased_L-12_H-768_A-12/bert_config.json [Content-Type=application/json]...\n",
            "Copying gs://bert_models/2018_10_18/uncased_L-12_H-768_A-12/bert_model.ckpt.data-00000-of-00001 [Content-Type=application/octet-stream]...\n",
            "Copying gs://bert_models/2018_10_18/uncased_L-12_H-768_A-12/bert_model.ckpt.index [Content-Type=application/octet-stream]...\n",
            "Copying gs://bert_models/2018_10_18/uncased_L-12_H-768_A-12/bert_model.ckpt.meta [Content-Type=application/octet-stream]...\n",
            "Copying gs://bert_models/2018_10_18/uncased_L-12_H-768_A-12/vocab.txt [Content-Type=text/plain]...\n",
            "/ [5/5 files][421.1 MiB/421.1 MiB] 100% Done                                    \n",
            "Operation completed over 5 objects/421.1 MiB.                                    \n",
            "gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_config.json\n",
            "gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt.data-00000-of-00001\n",
            "gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt.index\n",
            "gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt.meta\n",
            "gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/vocab.txt\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "fQiPKscKPpmT",
        "colab_type": "code",
        "outputId": "c2e7bf28-745d-46c8-f103-0cc3199dab90",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "bert_ckpt_dir    = gs_bert_ckpt_dir\n",
        "bert_ckpt_file   = os.path.join(bert_ckpt_dir, \"bert_model.ckpt\")\n",
        "bert_config_file = os.path.join(bert_ckpt_dir, \"bert_config.json\")\n",
        "\n",
        "print(\"Using BERT checkpoint from:\", bert_ckpt_dir)"
      ],
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Using BERT checkpoint from: gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pmFYvkylMwXn",
        "colab_type": "text"
      },
      "source": [
        "# The IMDB Movie Review Dataset"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cx5Yard1w5oe",
        "colab_type": "text"
      },
      "source": [
        "Let's use [kpe/params-flow](https://github.com/kpe/params-flow) to lazily download and unpack the IMDB Movie Review Dataset (`params_flow` has already been pip installed as a dependency of [kpe/bert-for-tf2](https://github.com/kpe/bert-for-tf2))."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "SziWodfBwy-N",
        "colab_type": "code",
        "outputId": "d357a321-d1ed-4c59-9a40-8f6df45e94d2",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 193
        }
      },
      "source": [
        "fetched_file = pf.utils.fetch_url(\"http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz\", fetch_dir=\".data\")\n",
        "unpack_dir   = pf.utils.unpack_archive(fetched_file)\n",
        "data_dir     = os.path.join(unpack_dir, \"aclImdb\")\n",
        "\n",
        "!ls -la {data_dir}"
      ],
      "execution_count": 11,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "aclImdb_v1.tar.gz: 84.1MB [00:02, 33.1MB/s]                            \n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "extracting to: .data/aclImdb_v1\n",
            "total 1732\n",
            "drwxr-xr-x 4 7297 1000   4096 Jun 26  2011 .\n",
            "drwxr-xr-x 3 root root   4096 Apr  6 15:14 ..\n",
            "-rw-r--r-- 1 7297 1000 903029 Jun 11  2011 imdbEr.txt\n",
            "-rw-r--r-- 1 7297 1000 845980 Apr 12  2011 imdb.vocab\n",
            "-rw-r--r-- 1 7297 1000   4037 Jun 26  2011 README\n",
            "drwxr-xr-x 4 7297 1000   4096 Apr 12  2011 test\n",
            "drwxr-xr-x 5 7297 1000   4096 Jun 26  2011 train\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qYwlvtPK08Ec",
        "colab_type": "text"
      },
      "source": [
        "# The TFRecords Conversion\n",
        "\n",
        "We now proceed with converting the raw dataset into TFRecord files by:\n",
        " - loading every dataset file\n",
        " - preprocessing each sample by replacing the `<br />` markup with a space\n",
        " - encoding the target label as an integer (1 for `pos`, 0 for `neg`) and tokenizing the review text with the BERT tokenizer (a vocab file is provided with the pre-trained model)\n",
        " - serializing the encoded examples into a TFRecord file"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "REOlJ8ipx8lW",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from functools import partial\n",
        "from glob import glob\n",
        "from multiprocessing import Pool\n",
        "\n",
        "\n",
        "def load_sample(path):\n",
        "    \"\"\"Loads an IMDB Movie Reviews data sample from a file.\"\"\"\n",
        "    label = path.split('/')[-2]  # the parent directory name: 'pos' or 'neg'\n",
        "    with open(path, \"r\", encoding=\"utf-8\") as f:\n",
        "        content = f.read()\n",
        "    return content, label\n",
        "\n",
        "def preprocess_sample(content, label):\n",
        "    \"\"\"Replaces the HTML line breaks in the review text with spaces.\"\"\"\n",
        "    content = content.replace(\"<br />\", \" \")\n",
        "    return content, label\n",
        "\n",
        "def encode_sample(content, label, tokenizer):\n",
        "    \"\"\"Converts the review into BERT token ids and the label into 1 (pos) or 0 (neg).\"\"\"\n",
        "    content = tokenizer.tokenize(content)\n",
        "    content = tokenizer.convert_tokens_to_ids(content)\n",
        "    label = int(label == \"pos\")\n",
        "    return content, label\n",
        "\n",
        "def serialize_example(token_ids, label):\n",
        "    \"\"\"Serializes the token ids and the label as a tf.train.Example proto.\"\"\"\n",
        "    feature = {\n",
        "        \"token_ids\": tf.train.Feature(int64_list=tf.train.Int64List(value=token_ids)),\n",
        "        \"label\":     tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))\n",
        "    }\n",
        "    proto = tf.train.Example(features=tf.train.Features(feature=feature))\n",
        "    return proto.SerializeToString()\n",
        "\n",
        "def to_tfrecord(file_path, tokenizer):\n",
        "    \"\"\"Loads, preprocesses, encodes and serializes a single sample file.\"\"\"\n",
        "    sample = load_sample(file_path)\n",
        "    sample = preprocess_sample(*sample)\n",
        "    sample = encode_sample(*sample, tokenizer=tokenizer)\n",
        "    sample = serialize_example(*sample)\n",
        "    return sample\n",
        "\n",
        "def convert_to_tfrecord_file(file_name, ds_dir, serializer_fn):\n",
        "    \"\"\"Serializes all the pos/ and neg/ samples in ds_dir into a single tfrecord file.\"\"\"\n",
        "    all_files  = glob(os.path.join(ds_dir, \"pos/*\"))\n",
        "    all_files += glob(os.path.join(ds_dir, \"neg/*\"))\n",
        "    with tf.io.TFRecordWriter(file_name) as writer:\n",
        "        # serialize the samples in parallel (imap_unordered does not preserve the file order)\n",
        "        with Pool() as pool:\n",
        "            protos = pool.imap_unordered(serializer_fn, all_files)\n",
        "            for proto in tqdm(protos, total=len(all_files)):\n",
        "                writer.write(proto)"
      ],
      "execution_count": 0,
      "outputs": []
    },
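    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "Before running the conversion on all 50000 reviews, the load, preprocess and encode steps can be sanity-checked on a single file. The cell below is a minimal, TensorFlow-free sketch under stated assumptions: `ToyTokenizer` and `encode_review_file` are hypothetical helpers written only for this check, the toy tokenizer merely lowercases and splits on whitespace (unlike the WordPiece-based `FullTokenizer`), and the review is written to a temporary `train/pos/` directory so that the label is derived from the path the same way `load_sample` does it. Without the `<br />` replacement, the toy tokenizer would turn the line break into the unknown tokens `movie<br` and `/>not`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import os\n",
        "import tempfile\n",
        "\n",
        "# Hypothetical stand-in for FullTokenizer, exposing the two methods used by encode_sample.\n",
        "class ToyTokenizer:\n",
        "    def __init__(self, vocab):\n",
        "        self.vocab = {tok: idx for idx, tok in enumerate(vocab)}\n",
        "\n",
        "    def tokenize(self, text):\n",
        "        # lowercase + whitespace split only (no WordPiece, no punctuation splitting)\n",
        "        return text.lower().split()\n",
        "\n",
        "    def convert_tokens_to_ids(self, tokens):\n",
        "        return [self.vocab.get(tok, self.vocab['[UNK]']) for tok in tokens]\n",
        "\n",
        "# Hypothetical helper mirroring load_sample -> preprocess_sample -> encode_sample (no serialization).\n",
        "def encode_review_file(path, tokenizer):\n",
        "    label = path.split('/')[-2]  # parent directory name: 'pos' or 'neg'\n",
        "    with open(path, 'r', encoding='utf-8') as f:\n",
        "        content = f.read()\n",
        "    content = content.replace('<br />', ' ')  # drop the HTML line breaks\n",
        "    token_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(content))\n",
        "    return token_ids, int(label == 'pos')\n",
        "\n",
        "tokenizer_demo = ToyTokenizer(['[PAD]', '[UNK]', 'great', 'movie', 'boring'])\n",
        "with tempfile.TemporaryDirectory() as tmp_dir:\n",
        "    pos_dir = os.path.join(tmp_dir, 'train', 'pos')\n",
        "    os.makedirs(pos_dir)\n",
        "    sample_path = os.path.join(pos_dir, '0_9.txt')\n",
        "    with open(sample_path, 'w', encoding='utf-8') as f:\n",
        "        f.write('Great movie<br />Not boring')\n",
        "    demo_ids, demo_label = encode_review_file(sample_path, tokenizer_demo)\n",
        "\n",
        "print(demo_ids, demo_label)  # [2, 3, 1, 4] 1"
      ],
      "execution_count": 0,
      "outputs": []
    },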
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "igZnhhTA3hmH",
        "colab_type": "text"
      },
      "source": [
        "The locations of the preprocessed train and test TFRecord files:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "VtvWU6630yPS",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "train_tfrecord_file = os.path.join(OUTPUT_DIR, \"data\", \"train.tfrecord\")\n",
        "test_tfrecord_file  = os.path.join(OUTPUT_DIR, \"data\", \"test.tfrecord\")"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sNtdVOHT01ha",
        "colab_type": "text"
      },
      "source": [
        "We can now instantiate the BERT tokenizer and do the TFRecord conversion:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "NyfPI4dayRXm",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 70
        },
        "outputId": "1acae4ec-3901-4e12-c399-9015ae8d875b"
      },
      "source": [
        "tokenizer = FullTokenizer(os.path.join(bert_ckpt_dir, \"vocab.txt\"))\n",
        "\n",
        "def serialize_to_tfrecord(ds_file):\n",
        "    return to_tfrecord(ds_file, tokenizer)\n",
        "        \n",
        "if not all([tf.io.gfile.exists(train_tfrecord_file),\n",
        "            tf.io.gfile.exists(test_tfrecord_file)]):\n",
        "  print(\"Preparing the [train, test].tfrecord files...\")\n",
        "  \n",
        "  convert_to_tfrecord_file(train_tfrecord_file, \n",
        "                           os.path.join(data_dir, \"train\"),\n",
        "                           serialize_to_tfrecord)\n",
        "  convert_to_tfrecord_file(test_tfrecord_file, \n",
        "                           os.path.join(data_dir, \"test\"),\n",
        "                           serialize_to_tfrecord)"
      ],
      "execution_count": 14,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Preparing the [train, test].tfrecord files...\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "100%|██████████| 25000/25000 [01:43<00:00, 242.68it/s]\n",
            "100%|██████████| 25000/25000 [01:37<00:00, 257.72it/s]\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "P1tmUIc97AEg",
        "colab_type": "text"
      },
      "source": [
        "Reading the tfrecord files is done using a `TFRecordDataset`. For batching the data we must make sure all sequences have the same length, so we also need to trim and pad each example:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "F9rORqVi5sii",
        "colab_type": "code",
        "outputId": "554f7d41-978e-4ecd-a10d-022f84199d8a",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "def tfrecord_to_dataset(filenames):\n",
        "  ds = tf.data.TFRecordDataset(filenames)\n",
        "  feature_description = {\n",
        "      \"token_ids\":  tf.io.VarLenFeature(tf.int64),\n",
        "      \"label\":      tf.io.FixedLenFeature([], tf.int64, default_value=-1)\n",
        "  }\n",
        "\n",
        "  def parse_proto(proto):\n",
        "    example = tf.io.parse_single_example(proto, feature_description)\n",
        "    token_ids, label = example[\"token_ids\"], example[\"label\"]\n",
        "    token_ids = tf.sparse.to_dense(token_ids)\n",
        "    return token_ids, label\n",
        "\n",
        "  return ds.map(parse_proto)\n",
        "\n",
        "pad_id, cls_id, sep_id = tokenizer.convert_tokens_to_ids([\"[PAD]\", \"[CLS]\", \"[SEP]\"])\n",
        "print(\"pad cls sep:\", pad_id, cls_id, sep_id)\n",
        "\n",
        "def create_pad_example_fn(pad_len, \n",
        "                          pad_id=pad_id, \n",
        "                          cls_id=cls_id, \n",
        "                          sep_id=sep_id,\n",
        "                          trim_beginning=True):\n",
        "  def pad_example(x, label):\n",
        "    seq_len = pad_len - 2\n",
        "    x = x[-seq_len:] if trim_beginning else x[:seq_len]\n",
        "    x = tf.pad(x, [[0, seq_len - tf.shape(x)[-1]]], constant_values=pad_id)\n",
        "    x = tf.concat([[cls_id], x, [sep_id]], axis=-1)\n",
        "    x = tf.reshape(x, (pad_len,))\n",
        "    label = tf.reshape(label, ())\n",
        "    label = tf.cast(label, tf.float32) # TPU's don't support uint8 see https://github.com/tensorflow/datasets/issues/832\n",
        "    x = tf.cast(x, tf.float32)\n",
        "    return  x, label\n",
        "  return pad_example    \n"
      ],
      "execution_count": 15,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "pad cls sep: 0 101 102\n"
          ],
          "name": "stdout"
        }
      ]
    },
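    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The trim and pad arithmetic of `pad_example` can be checked on toy token ids without building a dataset. The cell below is a hypothetical pure-Python mirror of `pad_example` (the `pad_token_ids` name is made up for this sketch), reusing the `[PAD]`, `[CLS]` and `[SEP]` ids 0, 101 and 102 printed above. Like `pad_example`, it inserts the padding between the review tokens and `[SEP]`, whereas the reference [google-research/bert](https://github.com/google-research/bert) preprocessing appends `[SEP]` directly after the last token and pads after it."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Hypothetical pure-Python mirror of pad_example, for checking the trim/pad logic.\n",
        "# Assumes pad_len > 2, so that at least one review token fits between [CLS] and [SEP].\n",
        "def pad_token_ids(token_ids, pad_len, pad_id=0, cls_id=101, sep_id=102,\n",
        "                  trim_beginning=True):\n",
        "    seq_len = pad_len - 2  # room left for the review tokens\n",
        "    if trim_beginning:\n",
        "        ids = token_ids[-seq_len:]  # keep the last seq_len tokens\n",
        "    else:\n",
        "        ids = token_ids[:seq_len]   # keep the first seq_len tokens\n",
        "    ids = ids + [pad_id] * (seq_len - len(ids))  # pad short sequences to seq_len\n",
        "    return [cls_id] + ids + [sep_id]  # padding ends up before [SEP], as in pad_example\n",
        "\n",
        "print(pad_token_ids([7, 8, 9], pad_len=6))                                 # [101, 7, 8, 9, 0, 102]\n",
        "print(pad_token_ids([1, 2, 3, 4, 5, 6], pad_len=6))                        # [101, 3, 4, 5, 6, 102]\n",
        "print(pad_token_ids([1, 2, 3, 4, 5, 6], pad_len=6, trim_beginning=False))  # [101, 1, 2, 3, 4, 102]"
      ],
      "execution_count": 0,
      "outputs": []
    },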
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hjlcx1E9-2ns",
        "colab_type": "text"
      },
      "source": [
        "BERT can handle at most 512 tokens, and its computational and memory requirements grow quadratically with the input sequence length (because of self-attention), so we will have to trim the sequences to a shorter length. But let's first check the sequence length distribution, and how many examples would be affected if we trimmed all sequences to 512 tokens:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "lbQjy2Kh7rNC",
        "colab_type": "code",
        "outputId": "7cba5be2-04a7-41f8-8723-081481447831",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 580
        }
      },
      "source": [
        "import pandas as pd\n",
        "\n",
        "def get_sequence_lengths(tfrecord_file):\n",
        "  \"\"\"Returns a list with the token sequence length of every example in a tfrecord file.\"\"\"\n",
        "  res = []\n",
        "\n",
        "  def sample_seq_len(tok_id, lab):\n",
        "    return tf.shape(tok_id)[0]\n",
        "\n",
        "  ds = tfrecord_to_dataset([tfrecord_file]).map(sample_seq_len).batch(128)\n",
        "  if tf.executing_eagerly():\n",
        "    for batch_of_seq_lens in ds:\n",
        "      res.extend(batch_of_seq_lens.numpy().tolist())\n",
        "  else:\n",
        "    it = tf.compat.v1.data.make_one_shot_iterator(ds)\n",
        "    seq_lens = it.get_next()\n",
        "    with tf.compat.v1.Session() as sess:\n",
        "      try:\n",
        "        while True:\n",
        "          res.extend(sess.run(seq_lens).tolist())\n",
        "      except tf.errors.OutOfRangeError:\n",
        "        pass  # the dataset is exhausted\n",
        "  return res\n",
        "\n",
        "\n",
        "def show_drop_count(lens, max_seq_len=512, name=\"ds\"):\n",
        "    \"\"\"Plots a histogram of the sequence lengths and prints how many exceed max_seq_len.\"\"\"\n",
        "    df = pd.DataFrame(lens)\n",
        "    df.hist(bins=100)\n",
        "    drop_count = df[df[0] > max_seq_len].shape[0]\n",
        "    print(\"{:>5s} drop count: {} of {} - {:5.2f}%\".format(name, drop_count, len(lens), 100*drop_count/len(lens)))\n",
        "\n",
        "train_lens = get_sequence_lengths(train_tfrecord_file)\n",
        "test_lens  = get_sequence_lengths(test_tfrecord_file)\n",
        "\n",
        "show_drop_count(train_lens, name=\"train\")\n",
        "show_drop_count(test_lens, name=\"test\")"
      ],
      "execution_count": 17,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "train drop count: 3290 of 25000 - 13.16%\n",
            " test drop count: 3057 of 25000 - 12.23%\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAEICAYAAACzliQjAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAX1klEQVR4nO3df4xdZZ3H8ffH8ssw2JaFnXRLs61r\nXcOPtcKkxdWYGYil1D+KiWvKEmgRM66WjWbdDYPGBUWSuuuPLBExddu1qOvYRQmTUhZrZUL4A2mr\npT9gsSOUlUml0ZbioItb9rt/3GfwMt6Z+2Nu7z13ns8rubnnPOc55z7f3unnnnvOufcqIjAzszy8\nrt0DMDOz1nHom5llxKFvZpYRh76ZWUYc+mZmGXHom5llxKFvZpYRh75ZnSSdLeleSS9JelbSX7d7\nTGa1OqXdAzDrQHcCvwO6gSXA/ZIej4gD7R2WWXXyJ3LNaifpTOAYcGFE/DS1fQMYjYiBtg7OrAY+\nvGNWnzcDJ8YDP3kcuKBN4zGri0PfrD5dwIsT2o4DZ7VhLGZ1c+ib1WcMeMOEtjcAv27DWMzq5tA3\nq89PgVMkLS5reyvgk7jWEXwi16xOkgaBAD5I6eqdbcBf+uod6wTe0zer30eA1wNHgG8DH3bgW6fw\nnr6ZWUa8p29mlhGHvplZRhz6ZmYZceibmWWk0F+4ds4558TChQsbWvell17izDPPbO6A2sB1FIvr\nKJaZUgc0t5bdu3f/MiLOrbSs0KG/cOFCdu3a1dC6w8PD9Pb2NndAbeA6isV1FMtMqQOaW4ukZydb\n5sM7ZmYZceibmWXEoW9mlhGHvplZRhz6ZmYZceibmWXEoW9mlhGHvplZRqqGvqQzJD0m6XFJByR9\nOrV/XdIzkvak25LULkl3SBqRtFfSxWXbWiPpYLqtOXllmZlZJbV8Ivdl4LKIGJN0KvCIpAfSsn+I\niHsm9L8SWJxuy4C7gGWSzgZuAXoo/erQbklDEXGsGYU028KB+1+dPrT+PW0ciZlZ81Td04+SsTR7\narpN9csrq4C703qPAnMkzQOuALZHxNEU9NuBFdMbvpmZ1aOmX86SNAvYDbwJuDMibpL0deDtlN4J\n7AAGIuJlSVuB9RHxSFp3B3AT0AucERGfTe2fAn4bEZ+f8Fj9QD9Ad3f3JYODgw0VNjY2RldXV0Pr\nAuwbPf7q9EXzZze8nemabh1F4TqKxXUUTzNr6evr2x0RPZWW1fSFaxHxCrBE0hzgXkkXAjcDvwBO\nAzZQCvbPTHewEbEhbY+enp5o9AuIpvvlRWvLD+9c0/h2pmumfKGU6ygW11E8raqlrqt3IuIF4CFg\nRUQcTodwXgb+DViauo0CC8pWOy+1TdZuZmYtUsvVO+emPXwkvR54N/Bf6Tg9kgRcBexPqwwB16Wr\neC4FjkfEYeBBYLmkuZLmAstTm5mZtUgth3fmAZvTcf3XAVsiYqukH0o6FxCwB/ib1H8bsBIYAX4D\nXA8QEUcl3QbsTP0+ExFHm1eKmZlVUzX0I2Iv8LYK7ZdN0j+AdZMs2wRsqnOMZmbWJP5ErplZRhz6\nZmYZKfRv5LZa+adwzcxmIu/pm5llxKFvZpYRh76ZWUYc+mZmGXHom5llxKFvZpYRh76ZWUYc+mZm\nGXHom5llxKFvZpYRh76ZWUYc+mZmGXHom5llxKFvZpYRh76ZWUYc+mZmGXHom5llpGroSzpD0mOS\nHpd0QNKnU/siST+SNCLpO5JOS+2np/mRtHxh2bZuTu1PSbriZBVlZmaV1bKn/zJwWUS8FVgCrJB0\nKfA54EsR8SbgGHBD6n8DcCy1fyn1Q9L5wGrgAmAF8BVJs5pZjJmZTa1q6EfJWJo9Nd0CuAy4J7Vv\nBq5K06vSPGn55ZKU2gcj4uWI
eAYYAZY2pQozM6uJIqJ6p9Ie+W7gTcCdwD8Dj6a9eSQtAB6IiAsl\n7QdWRMRzadnPgGXArWmdb6b2jWmdeyY8Vj/QD9Dd3X3J4OBgQ4WNjY3R1dVV1zr7Ro9XbL9o/uyG\nxtAMjdRRRK6jWFxH8TSzlr6+vt0R0VNp2Sm1bCAiXgGWSJoD3Au8pSkjq/xYG4ANAD09PdHb29vQ\ndoaHh6l33bUD91dsP3RNY2NohkbqKCLXUSyuo3haVUtdV+9ExAvAQ8DbgTmSxl80zgNG0/QosAAg\nLZ8N/Kq8vcI6ZmbWArVcvXNu2sNH0uuBdwNPUgr/96Vua4D70vRQmict/2GUjiENAavT1T2LgMXA\nY80qxMzMqqvl8M48YHM6rv86YEtEbJX0BDAo6bPAT4CNqf9G4BuSRoCjlK7YISIOSNoCPAGcANal\nw0ZmZtYiVUM/IvYCb6vQ/jQVrr6JiP8B/mqSbd0O3F7/MM3MrBn8iVwzs4w49M3MMuLQNzPLiEPf\nzCwjDn0zs4w49M3MMuLQNzPLiEPfzCwjDn0zs4w49M3MMuLQNzPLiEPfzCwjDn0zs4w49M3MMuLQ\nNzPLiEPfzCwjDn0zs4w49M3MMuLQNzPLiEPfzCwjDn0zs4xUDX1JCyQ9JOkJSQckfTS13yppVNKe\ndFtZts7NkkYkPSXpirL2FaltRNLAySnJzMwmc0oNfU4AH4+IH0s6C9gtaXta9qWI+Hx5Z0nnA6uB\nC4A/AX4g6c1p8Z3Au4HngJ2ShiLiiWYUYmZm1VUN/Yg4DBxO07+W9CQwf4pVVgGDEfEy8IykEWBp\nWjYSEU8DSBpMfR36ZmYtooiovbO0EHgYuBD4O2At8CKwi9K7gWOSvgw8GhHfTOtsBB5Im1gRER9M\n7dcCyyLixgmP0Q/0A3R3d18yODjYUGFjY2N0dXXVtc6+0eNV+1w0f3ZD42lUI3UUkesoFtdRPM2s\npa+vb3dE9FRaVsvhHQAkdQHfBT4WES9Kugu4DYh0/wXgA9MdbERsADYA9PT0RG9vb0PbGR4ept51\n1w7cX7XPoWsaG0+jGqmjiFxHsbiO4mlVLTWFvqRTKQX+tyLiewAR8XzZ8q8BW9PsKLCgbPXzUhtT\ntJuZWQvUcvWOgI3AkxHxxbL2eWXd3gvsT9NDwGpJp0taBCwGHgN2AoslLZJ0GqWTvUPNKcPMzGpR\ny57+O4BrgX2S9qS2TwBXS1pC6fDOIeBDABFxQNIWSidoTwDrIuIVAEk3Ag8Cs4BNEXGgibWYmVkV\ntVy98wigCou2TbHO7cDtFdq3TbWemZmdXP5ErplZRhz6ZmYZceibmWXEoW9mlhGHvplZRhz6ZmYZ\nceibmWXEoW9mlhGHvplZRhz6ZmYZceibmWXEoW9mlhGHvplZRhz6ZmYZceibmWXEoW9mlhGHvplZ\nRhz6ZmYZceibmWXEoW9mlpGqoS9pgaSHJD0h6YCkj6b2syVtl3Qw3c9N7ZJ0h6QRSXslXVy2rTWp\n/0FJa05eWWZmVkkte/ongI9HxPnApcA6SecDA8COiFgM7EjzAFcCi9OtH7gLSi8SwC3AMmApcMv4\nC4WZmbVG1dCPiMMR8eM0/WvgSWA+sArYnLptBq5K06uAu6PkUWCOpHnAFcD2iDgaEceA7cCKplZj\nZmZTUkTU3llaCDwMXAj8d0TMSe0CjkXEHElbgfUR8UhatgO4CegFzoiIz6b2TwG/jYjPT3iMfkrv\nEOju7r5kcHCwocLGxsbo6uqqa519o8er9rlo/uyGxtOoRuooItdRLK6jeJpZS19f3+6I6Km07JRa\nNyKpC/gu8LGIeLGU8yUREZJqf/WYQkRsADYA9PT0RG9vb0PbGR4ept511w7cX7XPoWsaG0+jGqmj\niFxHsbiO4mlVLTVdvSPpVEqB/62I+F5qfj4dtiHdH0nto8CCstXPS22TtZuZWYvUcvWOgI3Akx
Hx\nxbJFQ8D4FThrgPvK2q9LV/FcChyPiMPAg8BySXPTCdzlqc3MzFqklsM77wCuBfZJ2pPaPgGsB7ZI\nugF4Fnh/WrYNWAmMAL8BrgeIiKOSbgN2pn6fiYijTanCzMxqUjX00wlZTbL48gr9A1g3ybY2AZvq\nGaCZmTWPP5FrZpYRh76ZWUZqvmTTShaWXdZ5aP172jgSM7P6eU/fzCwjDn0zs4w49M3MMuLQNzPL\niEPfzCwjDn0zs4w49M3MMuLQNzPLiEPfzCwjDn0zs4w49M3MMuLQNzPLiEPfzCwjDn0zs4w49M3M\nMuLQNzPLiEPfzCwjVUNf0iZJRyTtL2u7VdKopD3ptrJs2c2SRiQ9JemKsvYVqW1E0kDzSzEzs2pq\n2dP/OrCiQvuXImJJum0DkHQ+sBq4IK3zFUmzJM0C7gSuBM4Hrk59zcyshar+Rm5EPCxpYY3bWwUM\nRsTLwDOSRoCladlIRDwNIGkw9X2i7hGbmVnDFBHVO5VCf2tEXJjmbwXWAi8Cu4CPR8QxSV8GHo2I\nb6Z+G4EH0mZWRMQHU/u1wLKIuLHCY/UD/QDd3d2XDA4ONlTY2NgYXV1dda2zb/R4Q48FcNH82Q2v\nO5VG6igi11EsrqN4mllLX1/f7ojoqbSs6p7+JO4CbgMi3X8B+ECD23qNiNgAbADo6emJ3t7ehrYz\nPDxMveuuHbi/occCOHRNfY9Vq0bqKCLXUSyuo3haVUtDoR8Rz49PS/oasDXNjgILyrqel9qYot3M\nzFqkoUs2Jc0rm30vMH5lzxCwWtLpkhYBi4HHgJ3AYkmLJJ1G6WTvUOPDNjOzRlTd05f0baAXOEfS\nc8AtQK+kJZQO7xwCPgQQEQckbaF0gvYEsC4iXknbuRF4EJgFbIqIA02vxszMplTL1TtXV2jeOEX/\n24HbK7RvA7bVNTozM2sqfyLXzCwjDn0zs4w49M3MMtLodfozxsJpXJtvZtZpvKdvZpYRh76ZWUYc\n+mZmGXHom5llxKFvZpYRh76ZWUYc+mZmGXHom5llJPsPZzVL+Ye8Dq1/TxtHYmY2Oe/pm5llJMs9\nfX/1gpnlynv6ZmYZceibmWXEoW9mlhGHvplZRhz6ZmYZqRr6kjZJOiJpf1nb2ZK2SzqY7uemdkm6\nQ9KIpL2SLi5bZ03qf1DSmpNTjpmZTaWWPf2vAysmtA0AOyJiMbAjzQNcCSxOt37gLii9SAC3AMuA\npcAt4y8UZmbWOlVDPyIeBo5OaF4FbE7Tm4GrytrvjpJHgTmS5gFXANsj4mhEHAO284cvJGZmdpIp\nIqp3khYCWyPiwjT/QkTMSdMCjkXEHElbgfUR8UhatgO4CegFzoiIz6b2TwG/jYjPV3isfkrvEuju\n7r5kcHCwocLGxsbo6uqquGzf6PGGtlmri+bPbtq2pqqjk7iOYnEdxdPMWvr6+nZHRE+lZdP+RG5E\nhKTqrxy1b28DsAGgp6cnent7G9rO8PAwk6279mR/InffS69OTvd7eKaqo5O4jmJxHcXTqloavXrn\n+XTYhnR/JLWPAgvK+p2X2iZrNzOzFmo09IeA8Stw1gD3lbVfl67iuRQ4HhGHgQeB5ZLmphO4y1Ob\nmZm1UNXDO5K+TemY/DmSnqN0Fc56YIukG4Bngfen7tuAlcAI8BvgeoCIOCrpNmBn6veZiJh4ctjM\nzE6yqqEfEVdPsujyCn0DWDfJdjYBm+oanZmZNZU/kWtmlhGHvplZRhz6ZmYZceibmWUky59LbCX/\nYLqZFYn39M3MMpLNnr5/DN3MzHv6ZmZZceibmWXEoW9mlpFsjukXga/kMbN2856+mVlGHPpmZhlx\n6JuZZcShb2aWEYe+mVlGfPVOm/hKHjNrB+/pm5llxKFvZpYRh76ZWUamFfqSDknaJ2mPpF2p7WxJ\n2yUdTPdzU7sk3SFpRNJeSRc3owAzM6tdM07k9kXEL8vmB4
AdEbFe0kCavwm4ElicbsuAu9J99nxS\n18xa5WQc3lkFbE7Tm4GrytrvjpJHgTmS5p2Exzczs0lMN/QD+L6k3ZL6U1t3RBxO078AutP0fODn\nZes+l9rMzKxFFBGNryzNj4hRSX8MbAf+FhiKiDllfY5FxFxJW4H1EfFIat8B3BQRuyZssx/oB+ju\n7r5kcHCwobGNjY3R1dX16vy+0eMNbafVLpo/+zXzE+voVK6jWFxH8TSzlr6+vt0R0VNp2bSO6UfE\naLo/IuleYCnwvKR5EXE4Hb45krqPAgvKVj8vtU3c5gZgA0BPT0/09vY2NLbh4WHK113bIT+XeOia\n3tfMT6yjU7mOYnEdxdOqWho+vCPpTElnjU8Dy4H9wBCwJnVbA9yXpoeA69JVPJcCx8sOA5mZWQtM\nZ0+/G7hX0vh2/j0i/lPSTmCLpBuAZ4H3p/7bgJXACPAb4PppPLaZmTWg4dCPiKeBt1Zo/xVweYX2\nANY1+ni58OWbZnYy+RO5BbZw4H72jR5/zQuBmdl0OPTNzDLir1buED7sY2bN4D19M7OMOPTNzDLi\n0Dczy4hD38wsIz6R24GmuoTTJ3nNbCre0zczy4j39GcYX9ppZlPxnr6ZWUYc+mZmGfHhnRnMh3rM\nbCKHfib8AmBm4NDPkl8AzPLl0M+cXwDM8uLQt1f5BcBs5vPVO2ZmGfGevlU02V7/ZF8B4XcGZp3B\noW9V1fJzjT40ZNYZZnTo+7dl26PSv/vHLzpBb+uHYmYTtDz0Ja0A/gWYBfxrRKxv9RisPZr1IjzZ\nO4mJ2/c7DrM/1NLQlzQLuBN4N/AcsFPSUEQ80cpxWGer9cXDh5zM/lCr9/SXAiMR8TSApEFgFeDQ\nt5OqlhPQfpGwHCgiWvdg0vuAFRHxwTR/LbAsIm4s69MP9KfZPweeavDhzgF+OY3hFoXrKBbXUSwz\npQ5obi1/GhHnVlpQuBO5EbEB2DDd7UjaFRE9TRhSW7mOYnEdxTJT6oDW1dLqD2eNAgvK5s9LbWZm\n1gKtDv2dwGJJiySdBqwGhlo8BjOzbLX08E5EnJB0I/AgpUs2N0XEgZP0cNM+RFQQrqNYXEexzJQ6\noEW1tPRErpmZtZe/cM3MLCMOfTOzjMy40Je0QtJTkkYkDbR7PNVIOiRpn6Q9knaltrMlbZd0MN3P\nTe2SdEeqba+ki9s47k2SjkjaX9ZW97glrUn9D0paU6BabpU0mp6XPZJWli27OdXylKQrytrb9rcn\naYGkhyQ9IemApI+m9o56Tqaoo6Oej/T4Z0h6TNLjqZZPp/ZFkn6UxvWddFELkk5P8yNp+cJqNTYk\nImbMjdLJ4Z8BbwROAx4Hzm/3uKqM+RBwzoS2fwIG0vQA8Lk0vRJ4ABBwKfCjNo77XcDFwP5Gxw2c\nDTyd7uem6bkFqeVW4O8r9D0//V2dDixKf2+z2v23B8wDLk7TZwE/TWPtqOdkijo66vlIYxPQlaZP\nBX6U/q23AKtT+1eBD6fpjwBfTdOrge9MVWOj45ppe/qvfs1DRPwOGP+ah06zCticpjcDV5W13x0l\njwJzJM1rxwAj4mHg6ITmesd9BbA9Io5GxDFgO7Di5I/+tSapZTKrgMGIeDkingFGKP3dtfVvLyIO\nR8SP0/SvgSeB+XTYczJFHZMp5PMBkP5tx9LsqekWwGXAPal94nMy/lzdA1wuSUxeY0NmWujPB35e\nNv8cU//BFEEA35e0W6WvoADojojDafoXQHeaLnp99Y676PXcmA59bBo/LEIH1JIOC7yN0p5lxz4n\nE+qADnw+JM2StAc4QukF9GfACxFxosK4Xh1zWn4c+COaXMtMC/1O9M6IuBi4Elgn6V3lC6P0/q7j\nrqvt1HGXuQv4M2AJcBj4QnuHUxtJXcB3gY9FxIvlyzrpOalQR0c+HxHxSkQsofTtA0uBt7R5SDMu\n9Dvuax4iYjTdHwHupf
SH8fz4YZt0fyR1L3p99Y67sPVExPPpP+z/AV/j92+nC1uLpFMpBeW3IuJ7\nqbnjnpNKdXTi81EuIl4AHgLeTulQ2vgHY8vH9eqY0/LZwK9oci0zLfQ76mseJJ0p6azxaWA5sJ/S\nmMevmlgD3Jemh4Dr0pUXlwLHy966F0G9434QWC5pbnq7vjy1td2EcyXvpfS8QKmW1elKi0XAYuAx\n2vy3l479bgSejIgvli3qqOdksjo67flIYz5X0pw0/XpKvyPyJKXwf1/qNvE5GX+u3gf8ML07m6zG\nxrTybHYrbpSuSvgppWNnn2z3eKqM9Y2Uzso/DhwYHy+l43g7gIPAD4Cz4/dXA9yZatsH9LRx7N+m\n9Db7fykdY7yhkXEDH6B0YmoEuL5AtXwjjXVv+k83r6z/J1MtTwFXFuFvD3gnpUM3e4E96bay056T\nKeroqOcjPf5fAD9JY94P/GNqfyOl0B4B/gM4PbWfkeZH0vI3VquxkZu/hsHMLCMz7fCOmZlNwaFv\nZpYRh76ZWUYc+mZmGXHom5llxKFvZpYRh76ZWUb+HylkQTlvcQd1AAAAAElFTkSuQmCC\n",
            "text/plain": [
              "<Figure size 432x288 with 1 Axes>"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYgAAAEICAYAAABF82P+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAXb0lEQVR4nO3df6zldZ3f8edLBDSMK7DYmymQDtbZ\nbFAq4g1g12yuGmHAEDSxBktkUDaz3YVUU9o47qbFH0uCzaqpKct2LFPRdR2pSpwgls4iN8Y/kB/u\n8GNA5ApjZDIyWYHRoy3bse/+cT4zns5+78y9556Ze87l+UhOzve8vz/O5833wovvj3NOqgpJkg72\nkuUegCRpPBkQkqROBoQkqZMBIUnqZEBIkjoZEJKkTgaEJKmTASEtUpKTk9yW5JdJfpzkXy73mKQj\n4aXLPQBpAt0I/D0wBZwNfDPJg1W1Y3mHJY1W/CS1tHBJTgCeA15XVT9stS8Cu6pq47IOThoxTzFJ\ni/M7wL794dA8CLx2mcYjHTEGhLQ4q4CfH1TbC7xiGcYiHVEGhLQ4PeC3Dqr9FvCLZRiLdEQZENLi\n/BB4aZK1A7XXA16g1orjRWppkZJsAQr4A/p3Md0B/HPvYtJK4xGEtHh/DLwc2AN8Gfgjw0ErkUcQ\nkqROHkFIkjoZEJKkTgaEJKmTASFJ6jTWX9Z3yimn1Jo1a4Za95e//CUnnHDCaAe0zOxpMtjTZFiJ\nPUG/rx/84Ad/V1WvWuq2xjog1qxZw/333z/UurOzs8zMzIx2QMvMniaDPU2GldgT9Pt6y1ve8uNR\nbMtTTJKkTgaEJKmTASFJ6mRASJI6GRCSpE4GhCSpkwEhSepkQEiSOhkQkqROY/1J6uW0ZuM3D0zv\nvOEdyzgSSVoeHkFIkjoZEJKkTgaEJKnTYQMiycuS3JvkwSQ7knys1T+f5Kkk29vj7FZPks8mmUvy\nUJJzBra1PskT7bH+yLUlSVqqhVykfgF4a1X1khwLfDfJt9q8f1dVXz1o+YuAte1xHnATcF6Sk4Hr\ngGmggAeSbK2q50bRiCRptA57BFF9vfby2PaoQ6xyKfCFtt49wIlJVgMXAtuq6tkWCtuAdUsbviTp\nSEnVof5b3xZKjgEeAF4D3FhVH07yeeBN9I8w7gI2VtULSW4Hbqiq77Z17wI+DMwAL6uqP2v1fw/8\nr6r684PeawOwAWBqauqNW7ZsGaqxXq/HqlWrhloX4OFdew9Mn3XqK4fezigttadxZE+TwZ4mR6/X\n45JLLnmgqqaXuq0FfQ6iqn4NnJ3kROC2JK8DPgL8FDgO2EQ/BD6+1AFV1aa2Paanp2vYX3xa6q9F\nXTn4OYjLh9/OKK3EX8Cyp8lgT5NjdnZ2ZNta1F1MVfU8cDewrqp2t9NILwD/DTi3LbYLOH1gtdNa\nbb66JGkMLeQuple1IweSvBx4O/CDdl2BJAHeCTzSVtkKXNHuZjof2FtVu4E7gQuSnJTkJOCCVpMk\njaGFnGJaDdzSrkO8BLi1qm5P8u0krwICbAf+VVv+DuBiYA74FfB+gKp6NskngPvach+vqmdH14ok\naZQOGxBV9RDwho76W+dZvoCr55m3Gdi8yDFKkpaBn6SWJHUyICRJnQwISVInA0KS1MmAkCR1MiAk\nSZ0MCElSJwNCktTJgJAkdTIgJEmdDAhJUicDQpLUyYCQJHUyICRJnQwISVInA0KS1MmAkCR1MiAk\nSZ0MCElSJwNCktTpsAGR5GVJ7k3yYJIdST7W6mck+V6SuSRfSXJcqx/fXs+1+WsGtvWRVn88yYVH\nqilJ0tIt5AjiBeCtVfV64GxgXZLzgU8Cn6mq1wDPAVe15a8Cnmv1z7TlSHImcBnwWmAd8BdJjhll\nM5Kk0TlsQFRfr708tj0KeCvw
1Va/BXhnm760vabNf1uStPqWqnqhqp4C5oBzR9KFJGnkXrqQhdr/\n6T8AvAa4EfgR8HxV7WuLPA2c2qZPBX4CUFX7kuwFfrvV7xnY7OA6g++1AdgAMDU1xezs7OI6anq9\n3tDrAlx71r4D00vZzigttadxZE+TwZ4mR6/XO/xCC7SggKiqXwNnJzkRuA343ZGN4B++1yZgE8D0\n9HTNzMwMtZ3Z2VmGXRfgyo3fPDC98/LhtzNKS+1pHNnTZLCnyTHK0FvUXUxV9TxwN/Am4MQk+wPm\nNGBXm94FnA7Q5r8S+NlgvWMdSdKYWchdTK9qRw4keTnwduAx+kHx7rbYeuAbbXpre02b/+2qqla/\nrN3ldAawFrh3VI1IkkZrIaeYVgO3tOsQLwFurarbkzwKbEnyZ8DfAje35W8GvphkDniW/p1LVNWO\nJLcCjwL7gKvbqStJ0hg6bEBU1UPAGzrqT9JxF1JV/W/gX8yzreuB6xc/TEnS0eYnqSVJnQwISVIn\nA0KS1MmAkCR1MiAkSZ0MCElSJwNCktTJgJAkdTIgJEmdFvRtri92awa/2fWGdyzjSCTp6PEIQpLU\nyYCQJHUyICRJnQwISVInA0KS1MmAkCR1MiAkSZ0MCElSJwNCktTpsAGR5PQkdyd5NMmOJB9s9Y8m\n2ZVke3tcPLDOR5LMJXk8yYUD9XWtNpdk45FpSZI0Cgv5qo19wLVV9f0krwAeSLKtzftMVf354MJJ\nzgQuA14L/GPgb5L8Tpt9I/B24GngviRbq+rRUTQiSRqtwwZEVe0GdrfpXyR5DDj1EKtcCmypqheA\np5LMAee2eXNV9SRAki1tWQNCksZQqmrhCydrgO8ArwP+DXAl8HPgfvpHGc8l+c/APVX1V22dm4Fv\ntU2sq6o/aPX3AedV1TUHvccGYAPA1NTUG7ds2TJUY71ej1WrVg21LsDDu/Z21s869ZVDb3OpltrT\nOLKnyWBPk6PX63HJJZc8UFXTS93Wgr/NNckq4GvAh6rq50luAj4BVHv+FPCBpQ6oqjYBmwCmp6dr\nZmZmqO3Mzs6y2HUHv7V1vn80Oy8fbjyjMExP486eJoM9TY7Z2dmRbWtBAZHkWPrh8KWq+jpAVT0z\nMP9zwO3t5S7g9IHVT2s1DlGXJI2ZhdzFFOBm4LGq+vRAffXAYu8CHmnTW4HLkhyf5AxgLXAvcB+w\nNskZSY6jfyF762jakCSN2kKOIH4PeB/wcJLtrfYnwHuTnE3/FNNO4A8BqmpHklvpX3zeB1xdVb8G\nSHINcCdwDLC5qnaMsBdJ0ggt5C6m7wLpmHXHIda5Hri+o37HodaTJI0PP0ktSepkQEiSOhkQkqRO\nBoQkqZMBIUnqZEBIkjoZEJKkTgaEJKmTASFJ6mRASJI6GRCSpE4GhCSpkwEhSepkQEiSOhkQkqRO\nBoQkqZMBIUnqtJCfHNWANRu/eWB65w3vWMaRSNKR5RGEJKmTASFJ6nTYgEhyepK7kzyaZEeSD7b6\nyUm2JXmiPZ/U6kny2SRzSR5Kcs7Atta35Z9Isv7ItSVJWqqFHEHsA66tqjOB84Grk5wJbATuqqq1\nwF3tNcBFwNr22ADcBP1AAa4DzgPOBa7bHyqSpPFz2ICoqt1V9f02/QvgMeBU4FLglrbYLcA72/Sl\nwBeq7x7gxCSrgQuBbVX1bFU9B2wD1o20G0nSyCzqLqYka4A3AN8Dpqpqd5v1U2CqTZ8K/GRgtadb\nbb76we+xgf6RB1NTU8zOzi5miAf0er1Fr3vtWfsWtfywYxvWMD2NO3uaDPY0OXq93si2teCASLIK\n+Brwoar6eZID86qqktQoBlRVm4BNANPT0zUzMzPUdmZnZ1nsulcO3MK6EDsvX9z2l2qYnsadPU0G\ne5ocowy9Bd3FlORY+uHwpar6eis/004d0Z73tPou4PSB1U9rtfnqkqQxtJC7mALcDDxWVZ8emL
UV\n2H8n0nrgGwP1K9rdTOcDe9upqDuBC5Kc1C5OX9BqkqQxtJBTTL8HvA94OMn2VvsT4Abg1iRXAT8G\n3tPm3QFcDMwBvwLeD1BVzyb5BHBfW+7jVfXsSLqQJI3cYQOiqr4LZJ7Zb+tYvoCr59nWZmDzYgYo\nSVoefpJaktTJgJAkdTIgJEmdDAhJUicDQpLUyYCQJHUyICRJnQwISVInA0KS1MmAkCR1MiAkSZ0M\nCElSJwNCktTJgJAkdVrUb1Lr/7dm4CdKd97wjmUciSSNnkcQkqROBoQkqZMBIUnqZEBIkjodNiCS\nbE6yJ8kjA7WPJtmVZHt7XDww7yNJ5pI8nuTCgfq6VptLsnH0rUiSRmkhRxCfB9Z11D9TVWe3xx0A\nSc4ELgNe29b5iyTHJDkGuBG4CDgTeG9bVpI0pg57m2tVfSfJmgVu71JgS1W9ADyVZA44t82bq6on\nAZJsacs+uugRS5KOiqV8DuKaJFcA9wPXVtVzwKnAPQPLPN1qAD85qH5e10aTbAA2AExNTTE7OzvU\n4Hq93qLXvfasfUO9FzD0OBdjmJ7GnT1NBnuaHL1eb2TbGjYgbgI+AVR7/hTwgVEMqKo2AZsApqen\na2ZmZqjtzM7Osth1rxz44Nti7bx8ce81jGF6Gnf2NBnsaXKMMvSGCoiqemb/dJLPAbe3l7uA0wcW\nPa3VOERdkjSGhrrNNcnqgZfvAvbf4bQVuCzJ8UnOANYC9wL3AWuTnJHkOPoXsrcOP2xJ0pF22COI\nJF8GZoBTkjwNXAfMJDmb/immncAfAlTVjiS30r/4vA+4uqp+3bZzDXAncAywuap2jLwbSdLILOQu\npvd2lG8+xPLXA9d31O8A7ljU6CRJy8ZPUkuSOhkQkqROBoQkqZMBIUnqZEBIkjr5k6Mj4s+PSlpp\nPIKQJHUyICRJnQwISVInA0KS1MmAkCR1MiAkSZ0MCElSJwNCktTpRf9BuTVL+JlRSVrJPIKQJHUy\nICRJnV6Up5g8rSRJh+cRhCSpkwEhSep02IBIsjnJniSPDNROTrItyRPt+aRWT5LPJplL8lCScwbW\nWd+WfyLJ+iPTjiRpVBZyBPF5YN1BtY3AXVW1FrirvQa4CFjbHhuAm6AfKMB1wHnAucB1+0NFkjSe\nDnuRuqq+k2TNQeVLgZk2fQswC3y41b9QVQXck+TEJKvbstuq6lmAJNvoh86Xl9zBGPLHgyStBMPe\nxTRVVbvb9E+BqTZ9KvCTgeWebrX56v9Akg30jz6YmppidnZ2qAH2er151732rH1DbXMYw46/y6F6\nmlT2NBnsaXL0er2RbWvJt7lWVSWpUQymbW8TsAlgenq6ZmZmhtrO7Ows86175VG8zXXn5d1jGMah\neppU9jQZ7GlyjDL0hr2L6Zl26oj2vKfVdwGnDyx3WqvNV5ckjalhA2IrsP9OpPXANwbqV7S7mc4H\n9rZTUXcCFyQ5qV2cvqDVJElj6rCnmJJ8mf5F5lOSPE3/bqQbgFuTXAX8GHhPW/wO4GJgDvgV8H6A\nqno2ySeA+9pyH99/wVqSNJ4WchfTe+eZ9baOZQu4ep7tbAY2L2p0kqRl4yepJUmdXpRf1nc0+ZkI\nSZPKIwhJUicDQpLUyVNMR5GnmyRNEo8gJEmdDAhJUicDQpLUyYCQJHUyICRJnQwISVInA0KS1MnP\nQSwTPxMhady9aAJizVH8FTlJWgk8xSRJ6mRASJI6GRCSpE4GhCSpkwEhSeq0pIBIsjPJw0m2J7m/\n1U5Osi3JE+35pFZPks8mmUvyUJJzRtGAJOnIGMURxFuq6uyqmm6vNwJ3VdVa4K72GuAiYG17bABu\nGsF7S5KOkCNxiulS4JY2fQvwzoH6F6rvHuDEJKuPwPtLkkYgVTX8yslTwHNAAf+lqjYleb6qTmzz\nAzxXVScmuR24oaq+2+bdBXy4qu4/aJsb6B9hMDU19cYtW7
YMNbZer8eqVasOvH54196htnO0nXXq\nK+edd3BPK4E9TQZ7mhy9Xo9LLrnkgYGzOkNb6iep31xVu5L8I2Bbkh8MzqyqSrKoBKqqTcAmgOnp\n6ZqZmRlqYLOzswyue+WEfJJ65+Uz8847uKeVwJ4mgz1NjtnZ2ZFta0kBUVW72vOeJLcB5wLPJFld\nVbvbKaQ9bfFdwOkDq5/Wahow31eC+H1Nko62oa9BJDkhySv2TwMXAI8AW4H1bbH1wDfa9FbginY3\n0/nA3qraPfTIJUlH1FKOIKaA2/qXGXgp8NdV9T+S3AfcmuQq4MfAe9rydwAXA3PAr4D3L+G9JUlH\n2NABUVVPAq/vqP8MeFtHvYCrh30/SdLR5SepJUmdDAhJUicDYkKs2fhNHt611x8+knTUGBCSpE4G\nhCSpkwEhSeq01K/a0DI41HUIP3EtaVQ8gpAkdfIIYoUZPLrwaELSUhgQK5hhIWkpPMUkSepkQEiS\nOnmK6UXIU0+SFsKAeJHwKzokLZYB8SLn0YSk+RgQOsCwkDTIgFAnw0KSAaHDmi8sDBFpZTMgtCgL\nudhtcEgrgwGhkZgvOBYSFgaKNJ6OekAkWQf8J+AY4L9W1Q1H6r28tXO8dO2Pa8/ax3x/hoaLtLyO\nakAkOQa4EXg78DRwX5KtVfXo0RyHxtdCjkQWu+6hGCrS/I72EcS5wFxVPQmQZAtwKWBAaFnMFyoG\nh3T0A+JU4CcDr58GzhtcIMkGYEN72Uvy+JDvdQrwd0OuO5b+tT0dNfnkklYfy56WyJ4mxynAPxnF\nhsbuInVVbQI2LXU7Se6vqukRDGls2NNksKfJsBJ7ggN9rRnFto72t7nuAk4feH1aq0mSxszRDoj7\ngLVJzkhyHHAZsPUoj0GStABH9RRTVe1Lcg1wJ/3bXDdX1Y4j9HZLPk01huxpMtjTZFiJPcEI+0pV\njWpbkqQVxF+UkyR1MiAkSZ1WXEAkWZfk8SRzSTYu93gWI8nOJA8n2Z7k/lY7Ocm2JE+055NaPUk+\n2/p8KMk5yzv630iyOcmeJI8M1BbdR5L1bfknkqxfjl4GxtLV00eT7Gr7a3uSiwfmfaT19HiSCwfq\nY/H3meT0JHcneTTJjiQfbPVJ30/z9TXJ++plSe5N8mDr6WOtfkaS77XxfaXd+EOS49vruTZ/zcC2\nOnudV1WtmAf9C98/Al4NHAc8CJy53ONaxPh3AqccVPuPwMY2vRH4ZJu+GPgWEOB84HvLPf6BMf8+\ncA7wyLB9ACcDT7bnk9r0SWPW00eBf9ux7Jntb+944Iz2N3nMOP19AquBc9r0K4AftnFP+n6ar69J\n3lcBVrXpY4HvtX1wK3BZq/8l8Edt+o+Bv2zTlwFfOVSvh3rvlXYEceCrPKrq74H9X+UxyS4FbmnT\ntwDvHKh/ofruAU5Msno5BniwqvoO8OxB5cX2cSGwraqerarngG3AuiM/+m7z9DSfS4EtVfVCVT0F\nzNH/2xybv8+q2l1V32/TvwAeo/9NB5O+n+braz6TsK+qqnrt5bHtUcBbga+2+sH7av8+/CrwtiRh\n/l7ntdICouurPA71xzFuCvifSR5I/ytHAKaqaneb/ikw1aYnrdfF9jEp/V3TTrls3n86hgnrqZ2C\neAP9/zNdMfvpoL5ggvdVkmOSbAf20A/hHwHPV9W+jvEdGHubvxf4bYboaaUFxKR7c1WdA1wEXJ3k\n9wdnVv84ceLvS14pfQA3Af8UOBvYDXxqeYezeElWAV8DPlRVPx+cN8n7qaOvid5XVfXrqjqb/rdP\nnAv87tF435UWEBP9VR5Vtas97wFuo/+H8Mz+U0fteU9bfNJ6XWwfY99fVT3T/sX9v8Dn+M3h+kT0\nlORY+v8R/VJVfb2VJ34/dfU16ftqv6p6HrgbeBP903z7P+w8OL4DY2/zXwn8jCF6WmkBMbFf5ZHk\nhCSv2D8NXAA8Qn/8++
8MWQ98o01vBa5od5ecD+wdODUwjhbbx53ABUlOaqcDLmi1sXHQNZ930d9f\n0O/psnY3yRnAWuBexujvs52Tvhl4rKo+PTBrovfTfH1N+L56VZIT2/TL6f+ezmP0g+LdbbGD99X+\nffhu4NvtaHC+Xue3HFflj+SD/t0WP6R/ju5Pl3s8ixj3q+nfYfAgsGP/2OmfO7wLeAL4G+Dk+s2d\nDTe2Ph8Gppe7h4Fevkz/MP7/0D/PedUwfQAfoH8hbQ54/xj29MU25ofav3yrB5b/09bT48BF4/b3\nCbyZ/umjh4Dt7XHxCthP8/U1yfvqnwF/28b+CPAfWv3V9P8DPwf8d+D4Vn9Zez3X5r/6cL3O9/Cr\nNiRJnVbaKSZJ0ogYEJKkTgaEJKmTASFJ6mRASJI6GRCSpE4GhCSp0/8DKcNVR821TecAAAAASUVO\nRK5CYII=\n",
            "text/plain": [
              "<Figure size 432x288 with 1 Axes>"
            ]
          },
          "metadata": {
            "tags": []
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "olroXq3_WiQt",
        "colab_type": "text"
      },
      "source": [
        "# The Model (finally)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ccp5trMwRtmr",
        "colab_type": "text"
      },
      "source": [
        "\n",
        "Now let's create a classification model using [adapter-BERT](https://arxiv.org/abs/1902.00751), a clever way \n",
        "of reducing the trainable parameter count: the original BERT weights are frozen, \n",
        "and the internal activations are adapted by two small bottleneck feed-forward layers (of width `adapter_size` below) in every BERT layer. \n",
        "\n",
        "For sequence classification, BERT uses a classifier acting on the `[CLS]` output alone. Such a classifier overfits easily and is difficult to regularize. \n",
        "As an alternative, we'll pool over the complete output sequence, concatenating max pooling and average pooling. For regularization we rely on layer normalization, L2 penalties on the trainable dense layers, small sizes for `adapter_size` and the classifier layers, and a one-cycle learning rate policy with a large learning rate.\n",
        "\n",
        "(The intuition is that during pre-training the `[CLS]` output is only trained on next sentence prediction, not to capture a general sequence representation. It would therefore need more adaptation to find the optimal `[CLS]` representation during fine-tuning, and that adaptation would at the same time overfit all the other activations.)\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6o2a5ZIvRcJq",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def create_model(max_seq_len, \n",
        "                 adapter_size=64,\n",
        "                 batch_size=None,\n",
        "                 init_ckpt_file=None,\n",
        "                 init_bert_ckpt_file=bert_ckpt_file,\n",
        "                ):\n",
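        "  \"\"\"Creates a classification model.\n",
        "  :param max_seq_len: input sequence length (in tokens)\n",
        "  :param adapter_size: adapter bottleneck size - arXiv:1902.00751\n",
        "  :param batch_size: fixed batch size of the input layer (or None)\n",
        "  :param init_ckpt_file: checkpoint of the complete model to restore (takes precedence)\n",
        "  :param init_bert_ckpt_file: pre-trained BERT checkpoint to load into the BERT layer\n",
        "  \"\"\"\n",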
        "\n",
        "  bert_params = params_from_pretrained_ckpt(os.path.dirname(init_bert_ckpt_file))\n",
        "  \n",
        "  # create the bert layer\n",
        "  bert_params.adapter_size = adapter_size\n",
        "  bert_params.adapter_init_scale = 1e-5\n",
        "  l_bert = BertModelLayer.from_params(bert_params, name=\"bert\")\n",
        "\n",
        "  max_pooling = True\n",
        "  if max_pooling:\n",
        "    model = keras.models.Sequential([\n",
        "        keras.layers.InputLayer(input_shape=(max_seq_len,),\n",
        "                                batch_size=batch_size,\n",
        "                                dtype=\"int32\", name=\"input_ids\"),\n",
        "        l_bert,\n",
        "\n",
        "        #keras.layers.TimeDistributed(keras.layers.Dropout(0.1)),\n",
        "        keras.layers.TimeDistributed(keras.layers.Dense(bert_params.hidden_size//32)),\n",
        "        keras.layers.TimeDistributed(keras.layers.LayerNormalization()),\n",
        "        keras.layers.TimeDistributed(keras.layers.Activation(\"tanh\")),\n",
        "\n",
        "        pf.Concat([\n",
        "          keras.layers.Lambda(lambda x: tf.math.reduce_max(x, axis=1)),  # GlobalMaxPooling1D   \n",
        "          keras.layers.GlobalAveragePooling1D()\n",
        "        ]),\n",
        "\n",
        "        #keras.layers.Dropout(0.5),\n",
        "        keras.layers.Dense(units=bert_params.hidden_size//16),\n",
        "        keras.layers.LayerNormalization(),\n",
        "        keras.layers.Activation(\"tanh\"),\n",
        "\n",
        "        keras.layers.Dense(units=2)\n",
        "    ])\n",
        "  else:\n",
        "    model = keras.models.Sequential([\n",
        "        keras.layers.InputLayer(input_shape=(max_seq_len,),\n",
        "                                batch_size=batch_size,\n",
        "                                dtype=\"int32\", name=\"input_ids\"),\n",
        "        l_bert,\n",
        "        keras.layers.Lambda(lambda seq: seq[:, 0, :]),\n",
        "        keras.layers.Dense(units=bert_params.hidden_size),\n",
        "        keras.layers.Activation(\"tanh\"),\n",
        "        keras.layers.Dense(units=2)      \n",
        "    ])\n",
        "  \n",
        "  model.build(input_shape=(batch_size, max_seq_len))\n",
        "  \n",
        "  # freeze the non-adapter BERT weights (when adapter_size is set)\n",
        "  l_bert.apply_adapter_freeze()\n",
        "  l_bert.embeddings_layer.trainable = False\n",
        "  # setting True above would unfreeze the embedding LayerNorms when the\n",
        "  # adapter freeze was applied (N.B. False works best in both cases)\n",
        "  \n",
        "  # apply global regularization on all trainable dense layers\n",
        "  pf.utils.add_dense_layer_loss(model,\n",
        "                                kernel_regularizer=keras.regularizers.l2(0.01),\n",
        "                                bias_regularizer=keras.regularizers.l2(0.01))\n",
        "  \n",
        "  model.compile(optimizer=pf.optimizers.RAdam(),\n",
        "                loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "                metrics=[keras.metrics.SparseCategoricalAccuracy(name=\"acc\")])\n",
        "\n",
        "  # pf.optimizers.lookahead.OptimizerLookaheadWrapper().wrap(model)  # file an issue\n",
        "  \n",
        "  # load the pre-trained model weights (once the input_shape is known)\n",
        "  if init_ckpt_file:\n",
        "    print(\"Loading model weights from:\", init_ckpt_file)\n",
        "    model.load_weights(init_ckpt_file)\n",
        "  elif init_bert_ckpt_file:\n",
        "    print(\"Loading pre-trained BERT layer from:\", init_bert_ckpt_file)\n",
        "    load_stock_weights(l_bert, init_bert_ckpt_file)\n",
        "\n",
        "      \n",
        "  return model\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fMEsolbw4Xo-",
        "colab_type": "text"
      },
      "source": [
        "The self-attention cost in a transformer grows quadratically with `max_seq_len`, \n",
        "but since we are equipped with a free TPU, we'll go for the maximum sequence length, which is 512 in BERT. We choose a really small `adapter_size` and a rather large learning rate, both of which are needed for regularization:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9cyROAuVljmp",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "adapter_size = 4\n",
        "max_seq_len  = 512\n",
        "batch_size   = 128"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kTy5V6Js5FTo",
        "colab_type": "text"
      },
      "source": [
        "We are now finally ready to create our model. To run it on a TPU, we connect to the TPU cluster, initialize the TPU system, and create the model inside a `TPUStrategy` scope:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "mAXDrezB0ijW",
        "colab_type": "code",
        "outputId": "eef3c748-8134-432c-b4e2-08129ac602fa",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        }
      },
      "source": [
        "\n",
        "tf.config.experimental_connect_to_cluster(tpu)\n",
        "tf.tpu.experimental.initialize_tpu_system(tpu)\n",
        "tpu_strategy = tf.distribute.experimental.TPUStrategy(tpu)\n",
        "\n",
        "\n",
        "with tpu_strategy.scope():\n",
        "  model = create_model(max_seq_len,\n",
        "                       adapter_size, \n",
        "                       batch_size=batch_size,\n",
        "                       init_bert_ckpt_file=bert_ckpt_file)\n",
        "\n",
        "model.summary()"
      ],
      "execution_count": 20,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "INFO:tensorflow:Initializing the TPU system: grpc://10.124.147.138:8470\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "INFO:tensorflow:Clearing out eager caches\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "INFO:tensorflow:Finished initializing TPU system.\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "INFO:tensorflow:Found TPU system:\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "INFO:tensorflow:*** Num TPU Cores: 8\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "INFO:tensorflow:*** Num TPU Workers: 1\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "INFO:tensorflow:*** Num TPU Cores Per Worker: 8\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "Loading pre-trained BERT layer from: gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt\n",
            "loader: No value for:[bert/encoder/layer_0/attention/output/adapter-down/kernel:0], i.e.:[bert/encoder/layer_0/attention/output/adapter-down/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_0/attention/output/adapter-down/bias:0], i.e.:[bert/encoder/layer_0/attention/output/adapter-down/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_0/attention/output/adapter-up/kernel:0], i.e.:[bert/encoder/layer_0/attention/output/adapter-up/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_0/attention/output/adapter-up/bias:0], i.e.:[bert/encoder/layer_0/attention/output/adapter-up/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_0/output/adapter-down/kernel:0], i.e.:[bert/encoder/layer_0/output/adapter-down/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_0/output/adapter-down/bias:0], i.e.:[bert/encoder/layer_0/output/adapter-down/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_0/output/adapter-up/kernel:0], i.e.:[bert/encoder/layer_0/output/adapter-up/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_0/output/adapter-up/bias:0], i.e.:[bert/encoder/layer_0/output/adapter-up/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_1/attention/output/adapter-down/kernel:0], i.e.:[bert/encoder/layer_1/attention/output/adapter-down/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_1/attention/output/adapter-down/bias:0], i.e.:[bert/encoder/layer_1/attention/output/adapter-down/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_1/attention/output/adapter-up/kernel:0], i.e.:[bert/encoder/layer_1/attention/output/adapter-up/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_1/attention/output/adapter-up/bias:0], i.e.:[bert/encoder/layer_1/attention/output/adapter-up/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_1/output/adapter-down/kernel:0], i.e.:[bert/encoder/layer_1/output/adapter-down/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_1/output/adapter-down/bias:0], i.e.:[bert/encoder/layer_1/output/adapter-down/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_1/output/adapter-up/kernel:0], i.e.:[bert/encoder/layer_1/output/adapter-up/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_1/output/adapter-up/bias:0], i.e.:[bert/encoder/layer_1/output/adapter-up/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_2/attention/output/adapter-down/kernel:0], i.e.:[bert/encoder/layer_2/attention/output/adapter-down/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_2/attention/output/adapter-down/bias:0], i.e.:[bert/encoder/layer_2/attention/output/adapter-down/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_2/attention/output/adapter-up/kernel:0], i.e.:[bert/encoder/layer_2/attention/output/adapter-up/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_2/attention/output/adapter-up/bias:0], i.e.:[bert/encoder/layer_2/attention/output/adapter-up/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_2/output/adapter-down/kernel:0], i.e.:[bert/encoder/layer_2/output/adapter-down/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_2/output/adapter-down/bias:0], i.e.:[bert/encoder/layer_2/output/adapter-down/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_2/output/adapter-up/kernel:0], i.e.:[bert/encoder/layer_2/output/adapter-up/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_2/output/adapter-up/bias:0], i.e.:[bert/encoder/layer_2/output/adapter-up/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_3/attention/output/adapter-down/kernel:0], i.e.:[bert/encoder/layer_3/attention/output/adapter-down/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_3/attention/output/adapter-down/bias:0], i.e.:[bert/encoder/layer_3/attention/output/adapter-down/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_3/attention/output/adapter-up/kernel:0], i.e.:[bert/encoder/layer_3/attention/output/adapter-up/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_3/attention/output/adapter-up/bias:0], i.e.:[bert/encoder/layer_3/attention/output/adapter-up/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_3/output/adapter-down/kernel:0], i.e.:[bert/encoder/layer_3/output/adapter-down/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_3/output/adapter-down/bias:0], i.e.:[bert/encoder/layer_3/output/adapter-down/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_3/output/adapter-up/kernel:0], i.e.:[bert/encoder/layer_3/output/adapter-up/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_3/output/adapter-up/bias:0], i.e.:[bert/encoder/layer_3/output/adapter-up/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_4/attention/output/adapter-down/kernel:0], i.e.:[bert/encoder/layer_4/attention/output/adapter-down/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_4/attention/output/adapter-down/bias:0], i.e.:[bert/encoder/layer_4/attention/output/adapter-down/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_4/attention/output/adapter-up/kernel:0], i.e.:[bert/encoder/layer_4/attention/output/adapter-up/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_4/attention/output/adapter-up/bias:0], i.e.:[bert/encoder/layer_4/attention/output/adapter-up/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_4/output/adapter-down/kernel:0], i.e.:[bert/encoder/layer_4/output/adapter-down/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_4/output/adapter-down/bias:0], i.e.:[bert/encoder/layer_4/output/adapter-down/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_4/output/adapter-up/kernel:0], i.e.:[bert/encoder/layer_4/output/adapter-up/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_4/output/adapter-up/bias:0], i.e.:[bert/encoder/layer_4/output/adapter-up/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_5/attention/output/adapter-down/kernel:0], i.e.:[bert/encoder/layer_5/attention/output/adapter-down/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_5/attention/output/adapter-down/bias:0], i.e.:[bert/encoder/layer_5/attention/output/adapter-down/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_5/attention/output/adapter-up/kernel:0], i.e.:[bert/encoder/layer_5/attention/output/adapter-up/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_5/attention/output/adapter-up/bias:0], i.e.:[bert/encoder/layer_5/attention/output/adapter-up/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_5/output/adapter-down/kernel:0], i.e.:[bert/encoder/layer_5/output/adapter-down/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_5/output/adapter-down/bias:0], i.e.:[bert/encoder/layer_5/output/adapter-down/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_5/output/adapter-up/kernel:0], i.e.:[bert/encoder/layer_5/output/adapter-up/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_5/output/adapter-up/bias:0], i.e.:[bert/encoder/layer_5/output/adapter-up/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_6/attention/output/adapter-down/kernel:0], i.e.:[bert/encoder/layer_6/attention/output/adapter-down/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_6/attention/output/adapter-down/bias:0], i.e.:[bert/encoder/layer_6/attention/output/adapter-down/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_6/attention/output/adapter-up/kernel:0], i.e.:[bert/encoder/layer_6/attention/output/adapter-up/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_6/attention/output/adapter-up/bias:0], i.e.:[bert/encoder/layer_6/attention/output/adapter-up/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_6/output/adapter-down/kernel:0], i.e.:[bert/encoder/layer_6/output/adapter-down/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_6/output/adapter-down/bias:0], i.e.:[bert/encoder/layer_6/output/adapter-down/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_6/output/adapter-up/kernel:0], i.e.:[bert/encoder/layer_6/output/adapter-up/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_6/output/adapter-up/bias:0], i.e.:[bert/encoder/layer_6/output/adapter-up/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_7/attention/output/adapter-down/kernel:0], i.e.:[bert/encoder/layer_7/attention/output/adapter-down/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_7/attention/output/adapter-down/bias:0], i.e.:[bert/encoder/layer_7/attention/output/adapter-down/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_7/attention/output/adapter-up/kernel:0], i.e.:[bert/encoder/layer_7/attention/output/adapter-up/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_7/attention/output/adapter-up/bias:0], i.e.:[bert/encoder/layer_7/attention/output/adapter-up/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_7/output/adapter-down/kernel:0], i.e.:[bert/encoder/layer_7/output/adapter-down/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_7/output/adapter-down/bias:0], i.e.:[bert/encoder/layer_7/output/adapter-down/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_7/output/adapter-up/kernel:0], i.e.:[bert/encoder/layer_7/output/adapter-up/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_7/output/adapter-up/bias:0], i.e.:[bert/encoder/layer_7/output/adapter-up/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_8/attention/output/adapter-down/kernel:0], i.e.:[bert/encoder/layer_8/attention/output/adapter-down/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_8/attention/output/adapter-down/bias:0], i.e.:[bert/encoder/layer_8/attention/output/adapter-down/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_8/attention/output/adapter-up/kernel:0], i.e.:[bert/encoder/layer_8/attention/output/adapter-up/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_8/attention/output/adapter-up/bias:0], i.e.:[bert/encoder/layer_8/attention/output/adapter-up/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_8/output/adapter-down/kernel:0], i.e.:[bert/encoder/layer_8/output/adapter-down/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_8/output/adapter-down/bias:0], i.e.:[bert/encoder/layer_8/output/adapter-down/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_8/output/adapter-up/kernel:0], i.e.:[bert/encoder/layer_8/output/adapter-up/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_8/output/adapter-up/bias:0], i.e.:[bert/encoder/layer_8/output/adapter-up/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_9/attention/output/adapter-down/kernel:0], i.e.:[bert/encoder/layer_9/attention/output/adapter-down/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_9/attention/output/adapter-down/bias:0], i.e.:[bert/encoder/layer_9/attention/output/adapter-down/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_9/attention/output/adapter-up/kernel:0], i.e.:[bert/encoder/layer_9/attention/output/adapter-up/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_9/attention/output/adapter-up/bias:0], i.e.:[bert/encoder/layer_9/attention/output/adapter-up/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_9/output/adapter-down/kernel:0], i.e.:[bert/encoder/layer_9/output/adapter-down/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_9/output/adapter-down/bias:0], i.e.:[bert/encoder/layer_9/output/adapter-down/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_9/output/adapter-up/kernel:0], i.e.:[bert/encoder/layer_9/output/adapter-up/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_9/output/adapter-up/bias:0], i.e.:[bert/encoder/layer_9/output/adapter-up/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_10/attention/output/adapter-down/kernel:0], i.e.:[bert/encoder/layer_10/attention/output/adapter-down/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_10/attention/output/adapter-down/bias:0], i.e.:[bert/encoder/layer_10/attention/output/adapter-down/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_10/attention/output/adapter-up/kernel:0], i.e.:[bert/encoder/layer_10/attention/output/adapter-up/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_10/attention/output/adapter-up/bias:0], i.e.:[bert/encoder/layer_10/attention/output/adapter-up/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_10/output/adapter-down/kernel:0], i.e.:[bert/encoder/layer_10/output/adapter-down/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_10/output/adapter-down/bias:0], i.e.:[bert/encoder/layer_10/output/adapter-down/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_10/output/adapter-up/kernel:0], i.e.:[bert/encoder/layer_10/output/adapter-up/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_10/output/adapter-up/bias:0], i.e.:[bert/encoder/layer_10/output/adapter-up/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_11/attention/output/adapter-down/kernel:0], i.e.:[bert/encoder/layer_11/attention/output/adapter-down/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_11/attention/output/adapter-down/bias:0], i.e.:[bert/encoder/layer_11/attention/output/adapter-down/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_11/attention/output/adapter-up/kernel:0], i.e.:[bert/encoder/layer_11/attention/output/adapter-up/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_11/attention/output/adapter-up/bias:0], i.e.:[bert/encoder/layer_11/attention/output/adapter-up/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_11/output/adapter-down/kernel:0], i.e.:[bert/encoder/layer_11/output/adapter-down/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_11/output/adapter-down/bias:0], i.e.:[bert/encoder/layer_11/output/adapter-down/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_11/output/adapter-up/kernel:0], i.e.:[bert/encoder/layer_11/output/adapter-up/kernel] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "loader: No value for:[bert/encoder/layer_11/output/adapter-up/bias:0], i.e.:[bert/encoder/layer_11/output/adapter-up/bias] in:[gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt]\n",
            "Done loading 196 BERT weights from: gs://kpe/colab/tpu_movie_reviews/bert_models/uncased_L-12_H-768_A-12/bert_model.ckpt into <bert.model.BertModelLayer object at 0x7f182ab50b00> (prefix:bert). Count of weights not found in the checkpoint was: [96]. Count of weights with mismatched shape: [0]\n",
            "Unused weights from checkpoint: \n",
            "\tbert/embeddings/token_type_embeddings\n",
            "\tbert/pooler/dense/bias\n",
            "\tbert/pooler/dense/kernel\n",
            "\tcls/predictions/output_bias\n",
            "\tcls/predictions/transform/LayerNorm/beta\n",
            "\tcls/predictions/transform/LayerNorm/gamma\n",
            "\tcls/predictions/transform/dense/bias\n",
            "\tcls/predictions/transform/dense/kernel\n",
            "\tcls/seq_relationship/output_bias\n",
            "\tcls/seq_relationship/output_weights\n",
            "Model: \"sequential\"\n",
            "_________________________________________________________________\n",
            "Layer (type)                 Output Shape              Param #   \n",
            "=================================================================\n",
            "bert (BertModelLayer)        (16, 512, 768)            109056096 \n",
            "_________________________________________________________________\n",
            "time_distributed (TimeDistri (16, 512, 24)             18456     \n",
            "_________________________________________________________________\n",
            "time_distributed_1 (TimeDist (16, 512, 24)             48        \n",
            "_________________________________________________________________\n",
            "time_distributed_2 (TimeDist (16, 512, 24)             0         \n",
            "_________________________________________________________________\n",
            "concat (Concat)              (16, 48)                  0         \n",
            "_________________________________________________________________\n",
            "dense_1 (Dense)              (16, 48)                  2352      \n",
            "_________________________________________________________________\n",
            "layer_normalization_1 (Layer (16, 48)                  96        \n",
            "_________________________________________________________________\n",
            "activation_1 (Activation)    (16, 48)                  0         \n",
            "_________________________________________________________________\n",
            "dense_2 (Dense)              (16, 2)                   98        \n",
            "=================================================================\n",
            "Total params: 109,077,146\n",
            "Trainable params: 223,898\n",
            "Non-trainable params: 108,853,248\n",
            "_________________________________________________________________\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SLmgCxwjacne",
        "colab_type": "text"
      },
      "source": [
        "We can now prepare the training dataset. When batching, we pass `drop_remainder=True` so that every batch contains exactly `batch_size` examples; a fixed (static) batch size should help the TPU run more efficiently:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "P8yFOEiIkswR",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
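        "# parse the training TFRecords and pad each example to max_seq_len tokens\n",
        "ds = tfrecord_to_dataset([train_tfrecord_file])\n",
        "ds = ds.map(create_pad_example_fn(pad_len=max_seq_len))\n",
        "ds = ds.cache()  # keep the padded examples in memory after the first pass\n",
        "# a buffer of 25000 covers the whole training set (a full shuffle), redone on every pass\n",
        "ds = ds.shuffle(buffer_size=25000, seed=4711, reshuffle_each_iteration=True)\n",
        "ds = ds.repeat()  # repeat indefinitely; fit() bounds each epoch with steps_per_epoch\n",
        "\n",
        "# drop_remainder=True keeps the batch dimension static, which helps the TPU run efficiently\n",
        "ds = ds.batch(batch_size, drop_remainder=True).prefetch(1)\n",
        "train_ds = ds"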
      ],
      "execution_count": 0,
      "outputs": []
    },
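    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "To illustrate what `drop_remainder=True` changes, here is a minimal sketch on a toy `tf.data.Dataset.range(10)` rather than the movie reviews data (it assumes TensorFlow 2.x, where `Dataset.element_spec` is available). Without the flag, the last batch may be smaller, so the batch dimension is reported as unknown (`None`); with it, the 2 leftover elements are dropped and every batch has the static size `4`:\n",
        "\n",
        "```python\n",
        "import tensorflow as tf\n",
        "\n",
        "toy_ds = tf.data.Dataset.range(10)\n",
        "\n",
        "# the last batch would hold only 2 elements, so the batch dimension is unknown (None)\n",
        "print(toy_ds.batch(4).element_spec)\n",
        "\n",
        "# the 2 leftover elements are dropped, so every batch has exactly 4 elements\n",
        "print(toy_ds.batch(4, drop_remainder=True).element_spec)\n",
        "print([batch.numpy().tolist() for batch in toy_ds.batch(4, drop_remainder=True)])\n",
        "```\n",
        "\n",
        "In the training pipeline above, the same flag gives every component of `train_ds` a fixed leading (batch) dimension of `batch_size`."
      ]
    },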
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5jFBFOOK46iF",
        "colab_type": "text"
      },
      "source": [
        "Once the model is trained, we'll store its weights in our bucket at the following checkpoint path:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "XrQ6Zh2PigjO",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "trained_ckpt_file = os.path.join(OUTPUT_DIR, 'checkpoints','trained','movie_reviews.ckpt')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "chq22MBJ5DEw",
        "colab_type": "text"
      },
      "source": [
        "And now let's proceed with the actual training on the TPU. We'll use a one-cycle learning rate policy: the learning rate is ramped up linearly over the first 20 epochs to a large maximum of `2e-3` (which also acts as a regularizer) and is then decayed towards `1e-6`.\n",
        "\n",
        "**N.B.** When you run this for the first time, it will likely fail because the TPU has no permission to access your bucket. In that case, check https://cloud.google.com/tpu/docs/storage-buckets and grant the Project Viewer and Storage Admin roles to the TPU's service account shown in the error message."
      ]
    },
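    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The learning rate schedule itself comes from `pf.utils.create_one_cycle_lr_scheduler` (params-flow). As a rough, plain-Python sketch of what a one-cycle schedule with these settings computes per epoch, the hypothetical `make_one_cycle_lr` below (not the params-flow source) ramps the learning rate up linearly over the warm-up epochs and then decays it exponentially towards `end_learn_rate`. The linear warm-up matches the rates logged in the training output below (`1e-4` in epoch 1, `1e-3` in epoch 10); the exponential shape of the decay is an assumption:\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "\n",
        "def make_one_cycle_lr(max_learn_rate=2e-3, end_learn_rate=1e-6,\n",
        "                      warmup_epoch_count=20, total_epoch_count=30):\n",
        "    # returns schedule(epoch) -> learning rate, with a 0-based epoch index\n",
        "    def schedule(epoch):\n",
        "        if epoch < warmup_epoch_count:\n",
        "            # linear warm-up: max_learn_rate is reached in the last warm-up epoch\n",
        "            return max_learn_rate / warmup_epoch_count * (epoch + 1)\n",
        "        # assumed exponential decay from max_learn_rate towards end_learn_rate\n",
        "        decay_epochs = total_epoch_count - warmup_epoch_count + 1\n",
        "        progress = (epoch - warmup_epoch_count + 1) / decay_epochs\n",
        "        return max_learn_rate * math.exp(math.log(end_learn_rate / max_learn_rate) * progress)\n",
        "    return schedule\n",
        "\n",
        "\n",
        "lr_schedule = make_one_cycle_lr()\n",
        "for epoch in range(30):\n",
        "    print('epoch {:2d}: lr = {:.2e}'.format(epoch + 1, lr_schedule(epoch)))\n",
        "```\n",
        "\n",
        "Because `schedule` takes only the epoch index, it could be wrapped as `keras.callbacks.LearningRateScheduler(make_one_cycle_lr(), verbose=1)`, which calls it at the start of every epoch."
      ]
    },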
    {
      "cell_type": "code",
      "metadata": {
        "id": "ZuLOkwonF-9S",
        "colab_type": "code",
        "outputId": "7c0dd780-c949-4964-eb62-89a995f0297f",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        }
      },
      "source": [
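        "#%%time\n",
        "\n",
        "# save_weights() in the TensorFlow checkpoint format writes '<prefix>.index' and\n",
        "# '<prefix>.data-*' files, so check for the index file rather than the prefix itself\n",
        "if tf.io.gfile.exists(trained_ckpt_file + \".index\"):\n",
        "  model.load_weights(trained_ckpt_file)\n",
        "else:\n",
        "  # one TensorBoard log directory per run, named after the start time (YYYYmmdd-HHMMSS)\n",
        "  log_dir = os.path.join(OUTPUT_DIR, \"log\", datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\"))\n",
        "  tensorboard_callback = keras.callbacks.TensorBoard(log_dir=log_dir)\n",
        "\n",
        "  # one-cycle policy: linear warm-up to max_learn_rate, then decay towards end_learn_rate\n",
        "  total_epoch_count = 30\n",
        "  lr_scheduler = pf.utils.create_one_cycle_lr_scheduler(max_learn_rate=2e-3,\n",
        "                                                        end_learn_rate=1e-6,\n",
        "                                                        warmup_epoch_count=20,\n",
        "                                                        total_epoch_count=total_epoch_count)\n",
        "\n",
        "  # train_ds repeats indefinitely, so steps_per_epoch bounds an epoch to about one pass over the 25000 reviews;\n",
        "  # shuffling is done by train_ds itself (fit's shuffle argument is ignored for tf.data datasets)\n",
        "  model.fit(train_ds,\n",
        "            epochs=total_epoch_count,\n",
        "            steps_per_epoch=25000//batch_size,\n",
        "            callbacks=[lr_scheduler,\n",
        "                       keras.callbacks.EarlyStopping(patience=10, \n",
        "                                                     restore_best_weights=True, \n",
        "                                                     monitor='loss'), # TODO: validation on a TPU - how?\n",
        "                       tensorboard_callback])\n",
        "  model.save_weights(trained_ckpt_file, overwrite=True)"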
      ],
      "execution_count": 23,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "\n",
            "Epoch 00001: LearningRateScheduler reducing learning rate to 0.0001.\n",
            "Epoch 1/30\n",
            "  2/195 [..............................] - ETA: 5:03 - loss: 1.3537 - acc: 0.4648WARNING:tensorflow:Method (on_train_batch_end) is slow compared to the batch update (1.473135). Check your callbacks.\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "195/195 [==============================] - 71s 363ms/step - loss: 0.8733 - acc: 0.5185 - lr: 1.0000e-04\n",
            "\n",
            "Epoch 00002: LearningRateScheduler reducing learning rate to 0.0002.\n",
            "Epoch 2/30\n",
            "195/195 [==============================] - 68s 349ms/step - loss: 0.6503 - acc: 0.7055 - lr: 2.0000e-04\n",
            "\n",
            "Epoch 00003: LearningRateScheduler reducing learning rate to 0.00030000000000000003.\n",
            "Epoch 3/30\n",
            "195/195 [==============================] - 68s 349ms/step - loss: 0.3525 - acc: 0.8838 - lr: 3.0000e-04\n",
            "\n",
            "Epoch 00004: LearningRateScheduler reducing learning rate to 0.0004.\n",
            "Epoch 4/30\n",
            "195/195 [==============================] - 68s 348ms/step - loss: 0.3005 - acc: 0.9059 - lr: 4.0000e-04\n",
            "\n",
            "Epoch 00005: LearningRateScheduler reducing learning rate to 0.0005.\n",
            "Epoch 5/30\n",
            "195/195 [==============================] - 68s 349ms/step - loss: 0.2734 - acc: 0.9132 - lr: 5.0000e-04\n",
            "\n",
            "Epoch 00006: LearningRateScheduler reducing learning rate to 0.0006000000000000001.\n",
            "Epoch 6/30\n",
            "195/195 [==============================] - 68s 350ms/step - loss: 0.2585 - acc: 0.9195 - lr: 6.0000e-04\n",
            "\n",
            "Epoch 00007: LearningRateScheduler reducing learning rate to 0.0007.\n",
            "Epoch 7/30\n",
            "195/195 [==============================] - 68s 350ms/step - loss: 0.2357 - acc: 0.9247 - lr: 7.0000e-04\n",
            "\n",
            "Epoch 00008: LearningRateScheduler reducing learning rate to 0.0008.\n",
            "Epoch 8/30\n",
            "195/195 [==============================] - 68s 350ms/step - loss: 0.2130 - acc: 0.9313 - lr: 8.0000e-04\n",
            "\n",
            "Epoch 00009: LearningRateScheduler reducing learning rate to 0.0009000000000000001.\n",
            "Epoch 9/30\n",
            "195/195 [==============================] - 68s 350ms/step - loss: 0.1929 - acc: 0.9377 - lr: 9.0000e-04\n",
            "\n",
            "Epoch 00010: LearningRateScheduler reducing learning rate to 0.001.\n",
            "Epoch 10/30\n",
            "195/195 [==============================] - 68s 350ms/step - loss: 0.1789 - acc: 0.9405 - lr: 0.0010\n",
            "\n",
            "Epoch 00011: LearningRateScheduler reducing learning rate to 0.0011.\n",
            "Epoch 11/30\n",
            "195/195 [==============================] - 67s 343ms/step - loss: 0.1794 - acc: 0.9370 - lr: 0.0011\n",
            "\n",
            "Epoch 00012: LearningRateScheduler reducing learning rate to 0.0012000000000000001.\n",
            "Epoch 12/30\n",
            "195/195 [==============================] - 69s 351ms/step - loss: 0.1617 - acc: 0.9439 - lr: 0.0012\n",
            "\n",
            "Epoch 00013: LearningRateScheduler reducing learning rate to 0.0013000000000000002.\n",
            "Epoch 13/30\n",
            "195/195 [==============================] - 68s 350ms/step - loss: 0.1565 - acc: 0.9461 - lr: 0.0013\n",
            "\n",
            "Epoch 00014: LearningRateScheduler reducing learning rate to 0.0014.\n",
            "Epoch 14/30\n",
            "195/195 [==============================] - 69s 351ms/step - loss: 0.1512 - acc: 0.9464 - lr: 0.0014\n",
            "\n",
            "Epoch 00015: LearningRateScheduler reducing learning rate to 0.0015.\n",
            "Epoch 15/30\n",
            "195/195 [==============================] - 68s 351ms/step - loss: 0.1435 - acc: 0.9506 - lr: 0.0015\n",
            "\n",
            "Epoch 00016: LearningRateScheduler reducing learning rate to 0.0016.\n",
            "Epoch 16/30\n",
            "195/195 [==============================] - 69s 352ms/step - loss: 0.1384 - acc: 0.9530 - lr: 0.0016\n",
            "\n",
            "Epoch 00017: LearningRateScheduler reducing learning rate to 0.0017000000000000001.\n",
            "Epoch 17/30\n",
            "195/195 [==============================] - 68s 351ms/step - loss: 0.1323 - acc: 0.9535 - lr: 0.0017\n",
            "\n",
            "Epoch 00018: LearningRateScheduler reducing learning rate to 0.0018000000000000002.\n",
            "Epoch 18/30\n",
            "195/195 [==============================] - 67s 342ms/step - loss: 0.1343 - acc: 0.9538 - lr: 0.0018\n",
            "\n",
            "Epoch 00019: LearningRateScheduler reducing learning rate to 0.0019.\n",
            "Epoch 19/30\n",
            "195/195 [==============================] - 68s 351ms/step - loss: 0.1291 - acc: 0.9565 - lr: 0.0019\n",
            "\n",
            "Epoch 00020: LearningRateScheduler reducing learning rate to 0.002.\n",
            "Epoch 20/30\n",
            "195/195 [==============================] - 68s 351ms/step - loss: 0.1261 - acc: 0.9568 - lr: 0.0020\n",
            "\n",
            "Epoch 00021: LearningRateScheduler reducing learning rate to 0.002.\n",
            "Epoch 21/30\n",
            "195/195 [==============================] - 68s 350ms/step - loss: 0.1196 - acc: 0.9600 - lr: 0.0020\n",
            "\n",
            "Epoch 00022: LearningRateScheduler reducing learning rate to 0.0009352484478226213.\n",
            "Epoch 22/30\n",
            "195/195 [==============================] - 68s 350ms/step - loss: 0.0958 - acc: 0.9692 - lr: 9.3525e-04\n",
            "\n",
            "Epoch 00023: LearningRateScheduler reducing learning rate to 0.0004373448295773112.\n",
            "Epoch 23/30\n",
            "195/195 [==============================] - 69s 352ms/step - loss: 0.0730 - acc: 0.9787 - lr: 4.3734e-04\n",
            "\n",
            "Epoch 00024: LearningRateScheduler reducing learning rate to 0.00020451303651271454.\n",
            "Epoch 24/30\n",
            "195/195 [==============================] - 68s 351ms/step - loss: 0.0654 - acc: 0.9817 - lr: 2.0451e-04\n",
            "\n",
            "Epoch 00025: LearningRateScheduler reducing learning rate to 9.563524997900369e-05.\n",
            "Epoch 25/30\n",
            "195/195 [==============================] - 68s 350ms/step - loss: 0.0600 - acc: 0.9843 - lr: 9.5635e-05\n",
            "\n",
            "Epoch 00026: LearningRateScheduler reducing learning rate to 4.4721359549995795e-05.\n",
            "Epoch 26/30\n",
            "195/195 [==============================] - 69s 352ms/step - loss: 0.0564 - acc: 0.9857 - lr: 4.4721e-05\n",
            "\n",
            "Epoch 00027: LearningRateScheduler reducing learning rate to 2.091279105182545e-05.\n",
            "Epoch 27/30\n",
            "195/195 [==============================] - 68s 351ms/step - loss: 0.0562 - acc: 0.9861 - lr: 2.0913e-05\n",
            "\n",
            "Epoch 00028: LearningRateScheduler reducing learning rate to 9.77932768542928e-06.\n",
            "Epoch 28/30\n",
            "195/195 [==============================] - 69s 351ms/step - loss: 0.0542 - acc: 0.9865 - lr: 9.7793e-06\n",
            "\n",
            "Epoch 00029: LearningRateScheduler reducing learning rate to 4.573050519273263e-06.\n",
            "Epoch 29/30\n",
            "195/195 [==============================] - 67s 343ms/step - loss: 0.0547 - acc: 0.9874 - lr: 4.5731e-06\n",
            "\n",
            "Epoch 00030: LearningRateScheduler reducing learning rate to 2.138469199982376e-06.\n",
            "Epoch 30/30\n",
            "195/195 [==============================] - 67s 344ms/step - loss: 0.0549 - acc: 0.9870 - lr: 2.1385e-06\n"
          ],
          "name": "stdout"
        }
      ]
    },
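    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "In the log above, the learning rate rises linearly by `1e-4` per epoch to a peak of `0.002` (epochs 20 and 21) and then decays exponentially.\n",
        "The schedule itself comes from `create_learning_rate_scheduler`, defined earlier in this notebook. The function below is only a hypothetical pure-Python sketch of a warm-up plus exponential-decay schedule.\n",
        "Its parameters `max_lr=2e-3` and `end_lr=1e-6` are inferred from the logged values rather than read from the helper. The name `make_warmup_exp_decay` is made up for this illustration:\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "def make_warmup_exp_decay(max_lr=2e-3, end_lr=1e-6,\n",
        "                          warmup_epoch_count=20, total_epoch_count=30):\n",
        "    # hypothetical sketch; max_lr and end_lr are inferred from the training log\n",
        "    def schedule(epoch):\n",
        "        # epoch is the 0-based index Keras passes to a LearningRateScheduler\n",
        "        if epoch < warmup_epoch_count:\n",
        "            # linear warm-up: add max_lr / warmup_epoch_count every epoch\n",
        "            return max_lr / warmup_epoch_count * (epoch + 1)\n",
        "        # exponential decay from max_lr, reaching end_lr at index total_epoch_count\n",
        "        progress = (epoch - warmup_epoch_count) / (total_epoch_count - warmup_epoch_count)\n",
        "        return max_lr * math.exp(math.log(end_lr / max_lr) * progress)\n",
        "    return schedule\n",
        "\n",
        "lr_schedule = make_warmup_exp_decay()\n",
        "for epoch in range(30):\n",
        "    print(epoch + 1, lr_schedule(epoch))  # 1-based epoch number, as in the log\n",
        "```\n",
        "\n",
        "Under these assumptions, the sketch should reproduce the logged learning rates up to floating-point rounding. Passing it as `keras.callbacks.LearningRateScheduler(lr_schedule, verbose=1)` would apply it during training."
      ]
    },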
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "E_Bbk6c_hPud",
        "colab_type": "text"
      },
      "source": [
        "# Evaluation\n",
        "\n",
        "For evaluation, we prepare the dataset without shuffling:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "dBkGzomBiWjG",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
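        "\n",
        "def to_model_ds(tfrecord_file, batch_size=batch_size, drop_remainder=True):\n",
        "  # read a single TFRecord file in order (no shuffling for evaluation)\n",
        "  ds = tfrecord_to_dataset([tfrecord_file])\n",
        "  # pad each example to max_seq_len (see create_pad_example_fn)\n",
        "  ds = ds.map(create_pad_example_fn(pad_len=max_seq_len))\n",
        "  # drop_remainder=True keeps every batch at batch_size examples (a static shape for the TPU)\n",
        "  ds = ds.batch(batch_size, drop_remainder=drop_remainder)\n",
        "  return ds"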
      ],
      "execution_count": 0,
      "outputs": []
    },
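    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "With `drop_remainder=True`, every batch holds exactly `batch_size` examples. This keeps tensor shapes static, as TPU execution generally requires, at the cost of silently skipping the last partial batch.\n",
        "The `195/195` progress bars in the training log suggest `batch_size = 128`. This is an inference; the actual value is set earlier in the notebook.\n",
        "Under that assumption, the hypothetical helper `batching_summary` below checks how many of the 25000 reviews in each split are actually evaluated:\n",
        "\n",
        "```python\n",
        "def batching_summary(num_examples, batch_size):\n",
        "    # hypothetical helper: full batches, examples kept, and examples dropped by drop_remainder=True\n",
        "    steps = num_examples // batch_size\n",
        "    kept = steps * batch_size\n",
        "    return steps, kept, num_examples - kept\n",
        "\n",
        "# 25000 reviews per IMDB split; batch_size=128 is inferred from the 195-step progress bars\n",
        "print(batching_summary(25000, 128))  # (195, 24960, 40)\n",
        "```\n",
        "\n",
        "So the accuracies reported below cover 24960 of the 25000 reviews in each split. `drop_remainder=False` would include the remaining 40, but only where a dynamically sized final batch is acceptable (e.g. on a CPU or GPU)."
      ]
    },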
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "loylsQ6V6FmA",
        "colab_type": "text"
      },
      "source": [
        "and call `model.evaluate()` on both the train and test sets:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-VUk8f8KYzeI",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 87
        },
        "outputId": "a9ec01ee-9ab9-495d-9bea-15a2b75f888e"
      },
      "source": [
        "\n",
        "# 25000 reviews per split; with drop_remainder=True the last partial batch is not evaluated\n",
        "_, train_acc = model.evaluate(to_model_ds(train_tfrecord_file), steps=25000//batch_size)\n",
        "_, test_acc = model.evaluate(to_model_ds(test_tfrecord_file), steps=25000//batch_size)\n",
        "\n",
        "print(\"train acc\", train_acc)\n",
        "print(\" test acc\", test_acc)"
      ],
      "execution_count": 25,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "195/195 [==============================] - 25s 126ms/step - loss: 0.0357 - acc: 0.9946\n",
            "195/195 [==============================] - 25s 128ms/step - loss: 0.2114 - acc: 0.9401\n",
            "train acc 0.9946314692497253\n",
            " test acc 0.9401442408561707\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "T9A53X8V5V8G",
        "colab_type": "text"
      },
      "source": [
        "We could also create a new model, load the trained checkpoint, and evaluate it.\n",
        "Beware, however, that this is very slow on a CPU, so you may prefer to uncomment and run the lines below in a new GPU session:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "BSqMu64oHzqy",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 87
        },
        "outputId": "56d7368f-c845-4fe8-96ad-752214cd333d"
      },
      "source": [
        "%%time \n",
        "\n",
        "trained_ckpt_file = os.path.join(OUTPUT_DIR, 'checkpoints','trained','movie_reviews.ckpt')\n",
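        "\n",
        "# Uncomment the lines below to rebuild the model from the trained checkpoint and re-evaluate it\n",
        "# (ideally in a GPU session); as committed, this cell only re-prints the accuracies computed above.\n",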
        "#model = create_model(max_seq_len, \n",
        "#                     adapter_size=adapter_size,\n",
        "#                     init_ckpt_file=trained_ckpt_file)\n",
        "\n",
        "#_, train_acc = model.evaluate(to_model_ds(train_tfrecord_file, drop_remainder=False))\n",
        "#_, test_acc = model.evaluate(to_model_ds(test_tfrecord_file, drop_remainder=False))\n",
        "\n",
        "print(\"train acc\", train_acc)\n",
        "print(\" test acc\", test_acc)"
      ],
      "execution_count": 26,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "train acc 0.9946314692497253\n",
            " test acc 0.9401442408561707\n",
            "CPU times: user 914 µs, sys: 0 ns, total: 914 µs\n",
            "Wall time: 766 µs\n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}